Commit 82fd7a8b authored by yan.yan's avatar yan.yan
Browse files

v2.1.5: add profile tool and python 3.6 for linux

parent f31eee3a
...@@ -23,6 +23,7 @@ from typing import List ...@@ -23,6 +23,7 @@ from typing import List
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
import numpy as np import numpy as np
class Point2VoxelCommon(pccm.ParameterizedClass): class Point2VoxelCommon(pccm.ParameterizedClass):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
super().__init__() super().__init__()
...@@ -35,7 +36,6 @@ class Point2VoxelCommon(pccm.ParameterizedClass): ...@@ -35,7 +36,6 @@ class Point2VoxelCommon(pccm.ParameterizedClass):
retf2_str = f"std::array<float, {self.ndim * 2}>" retf2_str = f"std::array<float, {self.ndim * 2}>"
self.calc_meta_ret = f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>" self.calc_meta_ret = f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>"
@pccm.pybind.mark
@pccm.static_function @pccm.static_function
def calc_meta_data(self): def calc_meta_data(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -80,7 +80,8 @@ class Point2VoxelCommon(pccm.ParameterizedClass): ...@@ -80,7 +80,8 @@ class Point2VoxelCommon(pccm.ParameterizedClass):
retf_str = f"std::array<float, {self.ndim}>" retf_str = f"std::array<float, {self.ndim}>"
retf2_str = f"std::array<float, {self.ndim * 2}>" retf2_str = f"std::array<float, {self.ndim * 2}>"
return code.ret(f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>") return code.ret(
f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>")
@pccm.static_function @pccm.static_function
def array2tvarray(self): def array2tvarray(self):
...@@ -112,11 +113,16 @@ class Point2VoxelCommon(pccm.ParameterizedClass): ...@@ -112,11 +113,16 @@ class Point2VoxelCommon(pccm.ParameterizedClass):
""") """)
return code.ret("std::array<T, N>") return code.ret("std::array<T, N>")
class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
"""this class don't support multi-thread. """this class don't support multi-thread.
create p2v for every thread. create p2v for every thread.
""" """
def __init__(self, dtype: dtypes.DType, ndim: int, layout: TensorGeneric, zyx: bool = True): def __init__(self,
dtype: dtypes.DType,
ndim: int,
layout: TensorGeneric,
zyx: bool = True):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewHashKernel) self.add_dependency(TensorView, TensorViewHashKernel)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
...@@ -278,6 +284,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -278,6 +284,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code return code
class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
super().__init__() super().__init__()
...@@ -289,11 +296,20 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -289,11 +296,20 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
cuda_funcs = [self.point_to_voxel_hash, self.point_to_voxel_hash_static] cuda_funcs = [
self.add_impl_only_param_class(cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx)) self.point_to_voxel_hash, self.point_to_voxel_hash_static
]
self.add_pybind_member("hashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_impl_only_param_class(
self.add_pybind_member("point_indice_data", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx))
self.add_pybind_member("hashdata",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("point_indice_data",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("voxels", "tv::Tensor", readwrite=False)
self.add_pybind_member("indices", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False)
...@@ -439,13 +455,13 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -439,13 +455,13 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>") return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def point_to_voxel_hash_static(self): def point_to_voxel_hash_static(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("points", "tv::Tensor") code.arg("points", "tv::Tensor")
code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor") code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data",
"tv::Tensor")
code.arg("vsize", f"std::array<float, {self.ndim}>") code.arg("vsize", f"std::array<float, {self.ndim}>")
code.arg("grid_size, grid_stride", f"std::array<int, {self.ndim}>") code.arg("grid_size, grid_stride", f"std::array<int, {self.ndim}>")
code.arg("coors_range", f"std::array<float, {self.ndim * 2}>") code.arg("coors_range", f"std::array<float, {self.ndim * 2}>")
...@@ -533,7 +549,10 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -533,7 +549,10 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx) self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx)
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
self.add_pybind_member("densehashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_pybind_member("densehashdata",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("voxels", "tv::Tensor", readwrite=False)
self.add_pybind_member("indices", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False)
...@@ -568,7 +587,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -568,7 +587,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret(self.p2v_c.calc_meta_ret) return code.ret(self.p2v_c.calc_meta_ret)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
def ctor(self): def ctor(self):
......
...@@ -4,13 +4,14 @@ from pathlib import Path ...@@ -4,13 +4,14 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from spconv.pytorch import ops from spconv.pytorch import ops, functional
from spconv.pytorch.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, from spconv.pytorch.conv import (SparseConv2d, SparseConv3d,
SparseConvTranspose3d, SparseInverseConv2d, SparseConvTranspose2d, SparseConvTranspose3d,
SparseInverseConv3d, SubMConv2d, SubMConv3d) SparseInverseConv2d, SparseInverseConv3d,
SubMConv2d, SubMConv3d)
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.identity import Identity from spconv.pytorch.identity import Identity
from spconv.pytorch.modules import SparseModule, SparseSequential from spconv.pytorch.modules import SparseModule, SparseSequential, assign_name_for_sparse_modules
from spconv.pytorch.ops import ConvAlgo from spconv.pytorch.ops import ConvAlgo
from spconv.pytorch.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pytorch.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable
......
...@@ -24,12 +24,13 @@ from torch.nn.parameter import Parameter ...@@ -24,12 +24,13 @@ from torch.nn.parameter import Parameter
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
import spconv.pytorch.functional as Fsp from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
from spconv.utils import nullcontext
def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo): def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
...@@ -205,6 +206,7 @@ class SparseConvolution(SparseModule): ...@@ -205,6 +206,7 @@ class SparseConvolution(SparseModule):
self.dilation) self.dilation)
else: else:
out_spatial_shape = spatial_shape out_spatial_shape = spatial_shape
# print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# input.update_grid(out_spatial_shape) # input.update_grid(out_spatial_shape)
# t = time.time() # t = time.time()
out_tensor = input.shadow_copy() out_tensor = input.shadow_copy()
...@@ -249,12 +251,16 @@ class SparseConvolution(SparseModule): ...@@ -249,12 +251,16 @@ class SparseConvolution(SparseModule):
indice_dict = input.indice_dict.copy() indice_dict = input.indice_dict.copy()
algo = self.algo algo = self.algo
if self.indice_key is not None : if self.indice_key is not None:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if datas is not None: if datas is not None:
msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key." msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
assert algo == datas.algo, msg assert algo == datas.algo, msg
# algo = datas.algo # algo = datas.algo
profile_ctx = nullcontext()
if input._timer is not None and self._sparse_unique_name:
profile_ctx = input._timer.namespace(self._sparse_unique_name)
with profile_ctx:
if algo == ConvAlgo.Native: if algo == ConvAlgo.Native:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if datas is not None: if datas is not None:
...@@ -308,10 +314,9 @@ class SparseConvolution(SparseModule): ...@@ -308,10 +314,9 @@ class SparseConvolution(SparseModule):
if indice_pairs.device != features.device: if indice_pairs.device != features.device:
indice_pairs_calc = indice_pairs.to(features.device) indice_pairs_calc = indice_pairs.to(features.device)
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv(features, self.weight, out_features = Fsp.indice_subm_conv(
indice_pairs_calc, features, self.weight, indice_pairs_calc,
indice_pair_num, indice_pair_num, outids.shape[0], algo, input._timer)
outids.shape[0], algo)
else: else:
if self.inverse: if self.inverse:
out_features = Fsp.indice_inverse_conv( out_features = Fsp.indice_inverse_conv(
...@@ -321,7 +326,8 @@ class SparseConvolution(SparseModule): ...@@ -321,7 +326,8 @@ class SparseConvolution(SparseModule):
out_features = Fsp.indice_conv(features, self.weight, out_features = Fsp.indice_conv(features, self.weight,
indice_pairs_calc, indice_pairs_calc,
indice_pair_num, indice_pair_num,
outids.shape[0], algo) outids.shape[0], algo,
input._timer)
else: else:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
...@@ -350,6 +356,7 @@ class SparseConvolution(SparseModule): ...@@ -350,6 +356,7 @@ class SparseConvolution(SparseModule):
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks masks = datas.masks
else: else:
with input._timer.namespace("gen_pairs"):
res = ops.get_indice_pairs_implicit_gemm( res = ops.get_indice_pairs_implicit_gemm(
indices, indices,
batch_size, batch_size,
...@@ -363,7 +370,8 @@ class SparseConvolution(SparseModule): ...@@ -363,7 +370,8 @@ class SparseConvolution(SparseModule):
subm=self.subm, subm=self.subm,
transpose=self.transposed, transpose=self.transposed,
is_train=self.training, is_train=self.training,
alloc=input.thrust_allocator) alloc=input.thrust_allocator,
timer=input._timer)
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
pair_fwd = res[2] pair_fwd = res[2]
...@@ -398,7 +406,8 @@ class SparseConvolution(SparseModule): ...@@ -398,7 +406,8 @@ class SparseConvolution(SparseModule):
features, self.weight, pair_fwd, pair_bwd, features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm) num_activate_out, masks, self.training, self.subm,
input._timer)
if self.bias is not None: if self.bias is not None:
out_features += self.bias out_features += self.bias
if input.benchmark: if input.benchmark:
......
...@@ -19,6 +19,7 @@ import torch ...@@ -19,6 +19,7 @@ import torch
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.ops import ThrustSortAllocator from spconv.pytorch.ops import ThrustSortAllocator
from spconv.tools import CUDAKernelTimer
if PYTORCH_VERSION >= [1, 8, 0]: if PYTORCH_VERSION >= [1, 8, 0]:
try: try:
...@@ -51,13 +52,14 @@ class IndiceData(object): ...@@ -51,13 +52,14 @@ class IndiceData(object):
class ImplicitGemmIndiceData(object): class ImplicitGemmIndiceData(object):
def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor, pair_fwd: torch.Tensor, def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor,
pair_bwd: torch.Tensor, pair_fwd: torch.Tensor, pair_bwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
masks: List[np.ndarray], out_spatial_shape, is_subm: bool, algo: ConvAlgo): masks: List[np.ndarray], out_spatial_shape, is_subm: bool,
algo: ConvAlgo):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -99,7 +101,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -99,7 +101,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
voxel_num: Optional[torch.Tensor] = None, voxel_num: Optional[torch.Tensor] = None,
indice_dict: Optional[dict] = None, indice_dict: Optional[dict] = None,
benchmark: bool = False, benchmark: bool = False,
permanent_thrust_allocator: bool = False): permanent_thrust_allocator: bool = False,
enable_timer: bool = False):
""" """
Args: Args:
features: [num_points, num_features] feature tensor features: [num_points, num_features] feature tensor
...@@ -133,6 +136,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -133,6 +136,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.thrust_allocator: Optional[ThrustSortAllocator] = None self.thrust_allocator: Optional[ThrustSortAllocator] = None
if permanent_thrust_allocator: if permanent_thrust_allocator:
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer)
def replace_feature(self, feature): def replace_feature(self, feature):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...@@ -144,7 +148,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -144,7 +148,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt.benchmark = self.benchmark new_spt.benchmark = self.benchmark
new_spt.benchmark_record = self.benchmark_record new_spt.benchmark_record = self.benchmark_record
new_spt.thrust_allocator = self.thrust_allocator new_spt.thrust_allocator = self.thrust_allocator
new_spt._timer = self._timer
return new_spt return new_spt
@property @property
...@@ -174,7 +178,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -174,7 +178,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
def spatial_size(self): def spatial_size(self):
return np.prod(self.spatial_shape) return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]: def find_indice_pair(
self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]:
if key is None: if key is None:
return None return None
if key in self.indice_dict: if key in self.indice_dict:
...@@ -208,4 +213,5 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -208,4 +213,5 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.benchmark) self.benchmark)
tensor.benchmark_record = self.benchmark_record tensor.benchmark_record = self.benchmark_record
tensor.thrust_allocator = self.thrust_allocator tensor.thrust_allocator = self.thrust_allocator
tensor._timer = self._timer
return tensor return tensor
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from cumm import tensorview as tv from cumm import tensorview as tv
import torch import torch
from typing import Optional, List from typing import Optional, List
_TORCH_DTYPE_TO_TV = { _TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32, torch.float32: tv.float32,
torch.float64: tv.float64, torch.float64: tv.float64,
...@@ -26,7 +27,10 @@ _TORCH_DTYPE_TO_TV = { ...@@ -26,7 +27,10 @@ _TORCH_DTYPE_TO_TV = {
torch.uint8: tv.uint8, torch.uint8: tv.uint8,
} }
def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Optional[List[int]] = None):
def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None,
shape: Optional[List[int]] = None):
assert ten.is_contiguous(), "must be contiguous tensor" assert ten.is_contiguous(), "must be contiguous tensor"
ptr = ten.data_ptr() ptr = ten.data_ptr()
device = ten.device device = ten.device
...@@ -42,9 +46,11 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Op ...@@ -42,9 +46,11 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Op
dtype = _TORCH_DTYPE_TO_TV[ten.dtype] dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
return tv.from_blob(ptr, shape, dtype, tv_device) return tv.from_blob(ptr, shape, dtype, tv_device)
def get_current_stream(): def get_current_stream():
return torch.cuda.current_stream().cuda_stream return torch.cuda.current_stream().cuda_stream
if __name__ == "__main__": if __name__ == "__main__":
a = torch.rand(2, 2) a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a) atv = torch_tensor_to_tv(a)
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
from typing import Optional
import spconv.pytorch.ops as ops from spconv.tools import CUDAKernelTimer
from spconv.pytorch import ops
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
import numpy as np import numpy as np
...@@ -27,23 +28,32 @@ from typing import List ...@@ -27,23 +28,32 @@ from typing import List
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
False, False,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
...@@ -51,18 +61,27 @@ class SparseConvFunction(Function): ...@@ -51,18 +61,27 @@ class SparseConvFunction(Function):
indice_pairs, indice_pairs,
indice_pair_num, indice_pair_num,
False, False,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
...@@ -70,13 +89,16 @@ class SparseInverseConvFunction(Function): ...@@ -70,13 +89,16 @@ class SparseInverseConvFunction(Function):
num_activate_out, num_activate_out,
True, True,
False, False,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
grad_output, grad_output,
...@@ -84,29 +106,40 @@ class SparseInverseConvFunction(Function): ...@@ -84,29 +106,40 @@ class SparseInverseConvFunction(Function):
indice_pair_num, indice_pair_num,
True, True,
False, False,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseImplicitGemmFunction(Function): class SparseImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features: torch.Tensor, filters: torch.Tensor, def forward(ctx,
pair_fwd: torch.Tensor, pair_bwd: torch.Tensor, features: torch.Tensor,
filters: torch.Tensor,
pair_fwd: torch.Tensor,
pair_bwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int, masks: List[np.ndarray], is_train: bool, num_activate_out: int,
is_subm: bool): masks: List[np.ndarray],
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
out, mask_out, mask_width = ops.implicit_gemm( out, mask_out, mask_width = ops.implicit_gemm(features, filters,
features, filters, pair_fwd, pair_mask_fwd_splits, pair_fwd,
mask_argsort_fwd_splits, num_activate_out, masks, is_train, is_subm) pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks,
is_train, is_subm, timer)
ctx.save_for_backward(features, filters, pair_fwd, pair_bwd) ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
ctx.mask_width = mask_width ctx.mask_width = mask_width
ctx.mask_out = mask_out ctx.mask_out = mask_out
ctx.timer = timer
ctx.pair_mask_fwd_splits = pair_mask_fwd_splits ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
ctx.pair_mask_bwd_splits = pair_mask_bwd_splits ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
...@@ -130,8 +163,9 @@ class SparseImplicitGemmFunction(Function): ...@@ -130,8 +163,9 @@ class SparseImplicitGemmFunction(Function):
# num_activate_out = ctx.num_activate_out # num_activate_out = ctx.num_activate_out
masks = ctx.masks masks = ctx.masks
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer
input_bp, filters_bp = ops.implicit_gemm_backward(features, input_bp, filters_bp = ops.implicit_gemm_backward(
features,
filters, filters,
grad_output, grad_output,
pair_fwd, pair_fwd,
...@@ -143,17 +177,26 @@ class SparseImplicitGemmFunction(Function): ...@@ -143,17 +177,26 @@ class SparseImplicitGemmFunction(Function):
mask_output_fwd=mask_out, mask_output_fwd=mask_out,
masks=masks, masks=masks,
mask_width=mask_width, mask_width=mask_width,
is_subm=is_subm) is_subm=is_subm,
None_9 = [None] * 10 timer=timer)
None_9 = [None] * 11
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
class SubMConvFunction(Function): class SubMConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
...@@ -161,13 +204,16 @@ class SubMConvFunction(Function): ...@@ -161,13 +204,16 @@ class SubMConvFunction(Function):
num_activate_out, num_activate_out,
False, False,
True, True,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
grad_output, grad_output,
...@@ -175,9 +221,10 @@ class SubMConvFunction(Function): ...@@ -175,9 +221,10 @@ class SubMConvFunction(Function):
indice_pair_num, indice_pair_num,
False, False,
True, True,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
...@@ -199,12 +246,14 @@ class SparseMaxPoolFunction(Function): ...@@ -199,12 +246,14 @@ class SparseMaxPoolFunction(Function):
indice_pairs, indice_pair_num) indice_pairs, indice_pair_num)
return input_bp, None, None, None return input_bp, None, None, None
class SparseMaxPoolImplicitGemmFunction(Function): class SparseMaxPoolImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, indice_pairs_bwd: torch.Tensor, def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
num_activate_out: int): indice_pairs_bwd: torch.Tensor, num_activate_out: int):
out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd, num_activate_out) out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
num_activate_out)
ctx.save_for_backward(indice_pairs_bwd, features, out) ctx.save_for_backward(indice_pairs_bwd, features, out)
return out return out
...@@ -213,10 +262,11 @@ class SparseMaxPoolImplicitGemmFunction(Function): ...@@ -213,10 +262,11 @@ class SparseMaxPoolImplicitGemmFunction(Function):
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs_bwd, features, out = ctx.saved_tensors indice_pairs_bwd, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_implicit_gemm_backward(features, out, grad_output, input_bp = ops.indice_maxpool_implicit_gemm_backward(
indice_pairs_bwd) features, out, grad_output, indice_pairs_bwd)
return input_bp, None, None, None return input_bp, None, None, None
indice_conv = SparseConvFunction.apply indice_conv = SparseConvFunction.apply
implicit_gemm = SparseImplicitGemmFunction.apply implicit_gemm = SparseImplicitGemmFunction.apply
indice_inverse_conv = SparseInverseConvFunction.apply indice_inverse_conv = SparseInverseConvFunction.apply
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import time import time
from collections import OrderedDict from collections import OrderedDict
...@@ -53,6 +52,7 @@ class SparseModule(nn.Module): ...@@ -53,6 +52,7 @@ class SparseModule(nn.Module):
def __init__(self, name=None): def __init__(self, name=None):
super().__init__() super().__init__()
self.name = name self.name = name
self._sparse_unique_name = ""
class SparseSequential(SparseModule): class SparseSequential(SparseModule):
...@@ -143,3 +143,8 @@ class SparseSequential(SparseModule): ...@@ -143,3 +143,8 @@ class SparseSequential(SparseModule):
input = module(input) input = module(input)
return input return input
def assign_name_for_sparse_modules(module: nn.Module):
for k, n in module.named_modules():
if isinstance(n, SparseModule):
n._sparse_unique_name = k
...@@ -26,14 +26,19 @@ from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream ...@@ -26,14 +26,19 @@ from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.utils import nullcontext
if hasattr(_ext, "cumm"): if hasattr(_ext, "cumm"):
CPU_ONLY_BUILD = False
from spconv.algo import GEMM, CONV # , GATHER, SCATTER from spconv.algo import GEMM, CONV # , GATHER, SCATTER
else: else:
CPU_ONLY_BUILD = True
GEMM = None GEMM = None
CONV = None CONV = None
import time import time
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
from cumm.gemm import codeops from cumm.gemm import codeops
from spconv.tools import CUDAKernelTimer
DEBUG = False DEBUG = False
...@@ -240,7 +245,8 @@ def get_indice_pairs(indices: torch.Tensor, ...@@ -240,7 +245,8 @@ def get_indice_pairs(indices: torch.Tensor,
return out_inds, pair, indice_num_per_loc return out_inds, pair, indice_num_per_loc
def get_indice_pairs_implicit_gemm(indices: torch.Tensor, def get_indice_pairs_implicit_gemm(
indices: torch.Tensor,
batch_size: int, batch_size: int,
spatial_shape: List[int], spatial_shape: List[int],
algo: ConvAlgo, algo: ConvAlgo,
...@@ -252,7 +258,8 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -252,7 +258,8 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
subm: bool = False, subm: bool = False,
transpose: bool = False, transpose: bool = False,
is_train: bool = True, is_train: bool = True,
alloc: Optional[ThrustSortAllocator] = None): alloc: Optional[ThrustSortAllocator] = None,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
""" """
Why return tuple? because pytorch seems don't support custom object in autograd. Why return tuple? because pytorch seems don't support custom object in autograd.
return: ( return: (
...@@ -336,7 +343,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -336,7 +343,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
out_inds_tv = torch_tensor_to_tv(out_inds) out_inds_tv = torch_tensor_to_tv(out_inds)
hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64)
pair_mask_tv = torch_tensor_to_tv(pair_mask, dtype=tv.uint32) pair_mask_tv = torch_tensor_to_tv(pair_mask, dtype=tv.uint32)
with timer.record("gen_subm_inds", stream):
SpconvOps.generate_subm_conv_inds(inds_tv, SpconvOps.generate_subm_conv_inds(inds_tv,
hashdata_tv, hashdata_tv,
pair_tv, pair_tv,
...@@ -358,12 +365,14 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -358,12 +365,14 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
mask_argsort_tv = torch_tensor_to_tv(mask_argsort) mask_argsort_tv = torch_tensor_to_tv(mask_argsort)
if alloc is None: if alloc is None:
alloc = ThrustSortAllocator(indices.device) alloc = ThrustSortAllocator(indices.device)
with timer.record("gen_subm_inds_sort", stream):
for j in range(mask_split_count): for j in range(mask_split_count):
# thrust don't provide two-step sort (first step return workspace size) # thrust don't provide two-step sort (first step return workspace size)
# so I use this stupid hack to use torch allocator without touch # so I use this stupid hack to use torch allocator without touch
# pytorch binary (c++). # pytorch binary (c++).
# f**k thrust # f**k thrust
SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], alloc.alloc, SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j],
alloc.alloc,
mask_argsort_tv[j], stream) mask_argsort_tv[j], stream)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)] pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)]
...@@ -391,7 +400,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -391,7 +400,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
dtype=indices.dtype, dtype=indices.dtype,
device=indices.device) device=indices.device)
indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq) indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq)
with timer.record("gen_conv_inds_stage1", stream):
SpconvOps.generate_conv_inds_mask_stage1(inds_tv, SpconvOps.generate_conv_inds_mask_stage1(inds_tv,
pair_bwd_tv, pair_bwd_tv,
indice_pairs_uniq_tv, indice_pairs_uniq_tv,
...@@ -452,7 +461,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -452,7 +461,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
CONV.stream_synchronize(stream) CONV.stream_synchronize(stream)
print("REGU_S2_PREPARE", time.time() - t) print("REGU_S2_PREPARE", time.time() - t)
t = time.time() t = time.time()
with timer.record("gen_conv_inds_stage2", stream):
SpconvOps.generate_conv_inds_mask_stage2(inds_tv, SpconvOps.generate_conv_inds_mask_stage2(inds_tv,
hashdata_tv, hashdata_tv,
pair_fwd_tv, pair_fwd_tv,
...@@ -492,7 +501,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -492,7 +501,7 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
mask_argsort_bwd_tv = torch_tensor_to_tv(mask_argsort_bwd) mask_argsort_bwd_tv = torch_tensor_to_tv(mask_argsort_bwd)
if alloc is None: if alloc is None:
alloc = ThrustSortAllocator(indices.device) alloc = ThrustSortAllocator(indices.device)
with timer.record("gen_conv_inds_sort", stream):
if is_mask_split: if is_mask_split:
for j in range(mask_split_count): for j in range(mask_split_count):
mask_tv = tv.from_numpy(masks[j]) mask_tv = tv.from_numpy(masks[j])
...@@ -530,24 +539,23 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, ...@@ -530,24 +539,23 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor,
if not is_train: if not is_train:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0],
alloc.alloc, alloc.alloc,
mask_argsort_fwd_tv[0], stream) mask_argsort_fwd_tv[0],
stream)
else: else:
if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1):
SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0], SpconvOps.sort_1d_by_key_allocator(
alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], mask_argsort_bwd_tv[0], stream)
stream) SpconvOps.sort_1d_by_key_allocator(
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], pair_mask_fwd_tv[0], alloc.alloc,
alloc.alloc,
mask_argsort_fwd_tv[0], stream) mask_argsort_fwd_tv[0], stream)
else: else:
SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], SpconvOps.sort_1d_by_key_allocator(
alloc.alloc, pair_mask_fwd_tv[0], alloc.alloc,
mask_argsort_fwd_tv[0], stream) mask_argsort_fwd_tv[0], stream)
SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0], SpconvOps.sort_1d_by_key_allocator(
alloc.alloc, pair_mask_bwd_tv[0], alloc.alloc,
mask_argsort_bwd_tv[0], mask_argsort_bwd_tv[0], stream)
stream)
if DEBUG: if DEBUG:
CONV.stream_synchronize(stream) CONV.stream_synchronize(stream)
print("REGU_S2_FINISH", time.time() - t) print("REGU_S2_FINISH", time.time() - t)
...@@ -587,7 +595,8 @@ def indice_conv(features: torch.Tensor, ...@@ -587,7 +595,8 @@ def indice_conv(features: torch.Tensor,
num_activate_out: int, num_activate_out: int,
inverse: bool = False, inverse: bool = False,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native): algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# filters: RSKC # filters: RSKC
# stream = get_current_stream() # stream = get_current_stream()
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
...@@ -717,7 +726,7 @@ def indice_conv(features: torch.Tensor, ...@@ -717,7 +726,7 @@ def indice_conv(features: torch.Tensor,
stream=stream) stream=stream)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
with timer.record("forward", stream):
for i, nhot in enumerate(indice_pair_num_cpu): for i, nhot in enumerate(indice_pair_num_cpu):
if subm and i == kv_center: if subm and i == kv_center:
continue continue
...@@ -770,7 +779,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -770,7 +779,8 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num: torch.Tensor, indice_pair_num: torch.Tensor,
inverse: bool = False, inverse: bool = False,
subm: bool = False, subm: bool = False,
algo: ConvAlgo = ConvAlgo.Native): algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
num_activate_out = out_bp.shape[0] num_activate_out = out_bp.shape[0]
...@@ -1046,12 +1056,16 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1046,12 +1056,16 @@ def indice_conv_backward(features: torch.Tensor,
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, def implicit_gemm(features: torch.Tensor,
filters: torch.Tensor,
pair_fwd: torch.Tensor, pair_fwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
num_activate_out: int, masks: List[np.ndarray], num_activate_out: int,
is_train: bool, is_subm: bool): masks: List[np.ndarray],
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
stream = get_current_stream() stream = get_current_stream()
# if DEBUG: # if DEBUG:
...@@ -1136,10 +1150,11 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, ...@@ -1136,10 +1150,11 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor,
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
with timer.record("implicit_gemm", stream):
for j in range(num_split): for j in range(num_split):
beta = 0 if j == 0 else 1 beta = 0 if j == 0 else 1
CONV.run_with_tuned_result(tune_res, CONV.run_with_tuned_result(
tune_res,
ConvOpType.kForward, ConvOpType.kForward,
features_tv, features_tv,
filters_tv, filters_tv,
...@@ -1166,16 +1181,20 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, ...@@ -1166,16 +1181,20 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor,
return out_features, mask_output_fwd, mask_width return out_features, mask_output_fwd, mask_width
def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor, def implicit_gemm_backward(features: torch.Tensor,
out_bp: torch.Tensor, pair_fwd: torch.Tensor, filters: torch.Tensor,
out_bp: torch.Tensor,
pair_fwd: torch.Tensor,
pair_bwd: torch.Tensor, pair_bwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
mask_output_fwd: torch.Tensor, mask_output_fwd: torch.Tensor,
masks: List[np.ndarray], mask_width: int, masks: List[np.ndarray],
is_subm: bool): mask_width: int,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
...@@ -1287,6 +1306,7 @@ def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor, ...@@ -1287,6 +1306,7 @@ def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor,
dtype=torch.int8, dtype=torch.int8,
device=features.device) device=features.device)
workspace_tv = torch_tensor_to_tv(workspace) workspace_tv = torch_tensor_to_tv(workspace)
with timer.record("implicit_gemm_backward", stream):
for j in range(num_split): for j in range(num_split):
beta = 0 if j == 0 else 1 beta = 0 if j == 0 else 1
if is_subm: if is_subm:
...@@ -1310,7 +1330,8 @@ def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor, ...@@ -1310,7 +1330,8 @@ def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor,
mask_width=-1, mask_width=-1,
beta=beta, beta=beta,
stream=stream) stream=stream)
CONV.run_with_tuned_result(wgrad_tune_res, CONV.run_with_tuned_result(
wgrad_tune_res,
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
features_tv, features_tv,
dfilters_tv, dfilters_tv,
...@@ -1445,4 +1466,3 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp, ...@@ -1445,4 +1466,3 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp,
out_bp_tv, din_tv, out_bp_tv, din_tv,
indice_pairs_tv, stream) indice_pairs_tv, stream)
return din return din
...@@ -24,11 +24,12 @@ from typing import List, Optional, Tuple, Union ...@@ -24,11 +24,12 @@ from typing import List, Optional, Tuple, Union
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
import spconv.pytorch.functional as Fsp from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.cppconstants import CPU_ONLY_BUILD from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.utils import nullcontext
class SparseMaxPool(SparseModule): class SparseMaxPool(SparseModule):
...@@ -128,11 +129,15 @@ class SparseMaxPool(SparseModule): ...@@ -128,11 +129,15 @@ class SparseMaxPool(SparseModule):
t = time.time() t = time.time()
out_padding = [0] * self.ndim out_padding = [0] * self.ndim
indice_dict = input.indice_dict.copy() indice_dict = input.indice_dict.copy()
profile_ctx = nullcontext()
if input._timer is not None and self._sparse_unique_name:
profile_ctx = input._timer.namespace(self._sparse_unique_name)
with profile_ctx:
if self.algo == ConvAlgo.Native: if self.algo == ConvAlgo.Native:
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs( outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
indices, batch_size, spatial_shape, ConvAlgo.Native, indices, batch_size, spatial_shape, ConvAlgo.Native,
self.kernel_size, self.stride, self.padding, self.dilation, out_padding, self.kernel_size, self.stride, self.padding, self.dilation,
False) out_padding, False)
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
...@@ -152,14 +157,17 @@ class SparseMaxPool(SparseModule): ...@@ -152,14 +157,17 @@ class SparseMaxPool(SparseModule):
algo=self.algo) algo=self.algo)
indice_dict[self.indice_key] = indice_data indice_dict[self.indice_key] = indice_data
else: else:
raise ValueError(f"indice key {self.indice_key} exists") raise ValueError(
f"indice key {self.indice_key} exists")
out_features = Fsp.indice_maxpool(features, out_features = Fsp.indice_maxpool(features,
indice_pairs.to(device), indice_pairs.to(device),
indice_pairs_num.to(device), indice_pairs_num.to(device),
outids.shape[0]) outids.shape[0])
else: else:
res = ops.get_indice_pairs_implicit_gemm(indices, with input._timer.namespace("gen_pairs"):
res = ops.get_indice_pairs_implicit_gemm(
indices,
batch_size, batch_size,
spatial_shape, spatial_shape,
self.algo, self.algo,
...@@ -170,7 +178,8 @@ class SparseMaxPool(SparseModule): ...@@ -170,7 +178,8 @@ class SparseMaxPool(SparseModule):
out_padding=out_padding, out_padding=out_padding,
subm=self.subm, subm=self.subm,
is_train=self.training, is_train=self.training,
alloc=input.thrust_allocator) alloc=input.thrust_allocator,
timer=input._timer)
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
pair_fwd = res[2] pair_fwd = res[2]
......
...@@ -15,18 +15,18 @@ ...@@ -15,18 +15,18 @@
import torch import torch
from torch.autograd import Function from torch.autograd import Function
import spconv.pytorch as spconv
#from torch.nn import Module #from torch.nn import Module
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from typing import List from typing import List
class JoinTable(SparseModule): # Module): class JoinTable(SparseModule): # Module):
def forward(self, input: List[SparseConvTensor]): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor( output = SparseConvTensor(torch.cat([i.features for i in input], 1),
torch.cat([i.features for i in input], 1), input[0].indices, input[0].indices, input[0].spatial_shape,
input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, input[0].batch_size, input[0].grid,
input[0].indice_dict) input[0].voxel_num, input[0].indice_dict)
output.benchmark_record = input[1].benchmark_record output.benchmark_record = input[1].benchmark_record
output.thrust_allocator = input[1].thrust_allocator output.thrust_allocator = input[1].thrust_allocator
return output return output
...@@ -37,10 +37,10 @@ class JoinTable(SparseModule): # Module): ...@@ -37,10 +37,10 @@ class JoinTable(SparseModule): # Module):
class AddTable(SparseModule): # Module): class AddTable(SparseModule): # Module):
def forward(self, input: List[SparseConvTensor]): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor( output = SparseConvTensor(sum([i.features for i in input]),
sum([i.features for i in input]), input[0].indices, input[0].indices, input[0].spatial_shape,
input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, input[0].batch_size, input[0].grid,
input[0].indice_dict) input[0].voxel_num, input[0].indice_dict)
output.benchmark_record = input[1].benchmark_record output.benchmark_record = input[1].benchmark_record
output.thrust_allocator = input[1].thrust_allocator output.thrust_allocator = input[1].thrust_allocator
return output return output
......
...@@ -89,17 +89,18 @@ class PointToVoxel(object): ...@@ -89,17 +89,18 @@ class PointToVoxel(object):
voxels_tv = torch_tensor_to_tv(self.voxels) voxels_tv = torch_tensor_to_tv(self.voxels)
indices_tv = torch_tensor_to_tv(self.indices) indices_tv = torch_tensor_to_tv(self.indices)
num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel) num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel)
hashdata_tv = torch_tensor_to_tv(self.hashdata, hashdata_tv = torch_tensor_to_tv(
self.hashdata,
dtype=tv.custom128, dtype=tv.custom128,
shape=[self.hashdata.shape[0]]) shape=[self.hashdata.shape[0]])
point_indice_data_tv = torch_tensor_to_tv(self.point_indice_data) point_indice_data_tv = torch_tensor_to_tv(
self.point_indice_data)
res = SpconvOps.point2voxel_cuda(pc_tv, voxels_tv, indices_tv, res = SpconvOps.point2voxel_cuda(
num_per_voxel_tv, hashdata_tv, pc_tv, voxels_tv, indices_tv, num_per_voxel_tv,
point_indice_data_tv, self.vsize, hashdata_tv, point_indice_data_tv, self.vsize,
self.grid_size, self.grid_stride, self.grid_size, self.grid_stride, self.coors_range,
self.coors_range, empty_mean, empty_mean, clear_voxels, stream)
clear_voxels, stream)
num_voxels = res[0].shape[0] num_voxels = res[0].shape[0]
else: else:
pc_tv = torch_tensor_to_tv(pc) pc_tv = torch_tensor_to_tv(pc)
...@@ -111,8 +112,9 @@ class PointToVoxel(object): ...@@ -111,8 +112,9 @@ class PointToVoxel(object):
res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv, res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv,
num_per_voxel_tv, hashdata_tv, num_per_voxel_tv, hashdata_tv,
self.vsize, self.grid_size, self.vsize, self.grid_size,
self.grid_stride, self.coors_range, self.grid_stride,
empty_mean, clear_voxels) self.coors_range, empty_mean,
clear_voxels)
num_voxels = res[0].shape[0] num_voxels = res[0].shape[0]
return (self.voxels[:num_voxels], self.indices[:num_voxels], return (self.voxels[:num_voxels], self.indices[:num_voxels],
......
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from spconv.cppconstants import CPU_ONLY_BUILD
import contextlib
from spconv.utils import nullcontext
if not CPU_ONLY_BUILD:
from cumm.tensorview import CUDAKernelTimer as _CUDAKernelTimer
class CUDAKernelTimer:
def __init__(self, enable: bool = True) -> None:
self.enable = enable and not CPU_ONLY_BUILD
if self.enable:
self._timer = _CUDAKernelTimer(enable)
else:
self._timer = None
@contextlib.contextmanager
def _namespace(self, name: str):
assert self._timer is not None
self._timer.push(name)
try:
yield
finally:
self._timer.pop()
@contextlib.contextmanager
def _record(self, name: str, stream: int = 0):
assert self._timer is not None
self._timer.push(name)
try:
self._timer.insert_pair("", "start", "stop")
self._timer.record("start", stream)
yield
self._timer.record("stop", stream)
finally:
self._timer.pop()
def namespace(self, name: str):
if self.enable:
return self._namespace(name)
else:
return nullcontext()
def record(self, name: str, stream: int = 0):
if self.enable:
return self._record(name, stream)
else:
return nullcontext()
def get_all_pair_time(self) -> Dict[str, float]:
if self.enable:
assert self._timer is not None
return self._timer.get_all_pair_duration()
else:
return {}
@staticmethod
def collect_by_name(name: str, res: Dict[str, float]):
filtered_res: Dict[str, float] = {}
for k, v in res.items():
k_split = k.split(".")
if name in k_split:
filtered_res[k] = v
return filtered_res
...@@ -14,17 +14,36 @@ ...@@ -14,17 +14,36 @@
import numpy as np import numpy as np
from cumm import tensorview as tv from cumm import tensorview as tv
from contextlib import AbstractContextManager
from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.core_cc.csrc.sparse.all.ops_cpu1d import Point2VoxelCPU as Point2VoxelCPU1d from spconv.core_cc.csrc.sparse.all.ops_cpu1d import Point2VoxelCPU as Point2VoxelCPU1d
from spconv.core_cc.csrc.sparse.all.ops_cpu2d import Point2VoxelCPU as Point2VoxelCPU2d from spconv.core_cc.csrc.sparse.all.ops_cpu2d import Point2VoxelCPU as Point2VoxelCPU2d
from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU as Point2VoxelCPU3d from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU as Point2VoxelCPU3d
from spconv.core_cc.csrc.sparse.all.ops_cpu4d import Point2VoxelCPU as Point2VoxelCPU4d from spconv.core_cc.csrc.sparse.all.ops_cpu4d import Point2VoxelCPU as Point2VoxelCPU4d
import spconv.core_cc.csrc.sparse.all as __all
IS_CPU_ONLY_BUILD = hasattr(__all, "ops1d") if not CPU_ONLY_BUILD:
if IS_CPU_ONLY_BUILD:
from spconv.core_cc.csrc.sparse.all.ops1d import Point2Voxel as Point2VoxelGPU1d from spconv.core_cc.csrc.sparse.all.ops1d import Point2Voxel as Point2VoxelGPU1d
from spconv.core_cc.csrc.sparse.all.ops2d import Point2Voxel as Point2VoxelGPU2d from spconv.core_cc.csrc.sparse.all.ops2d import Point2Voxel as Point2VoxelGPU2d
from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel as Point2VoxelGPU3d from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel as Point2VoxelGPU3d
from spconv.core_cc.csrc.sparse.all.ops4d import Point2Voxel as Point2VoxelGPU4d from spconv.core_cc.csrc.sparse.all.ops4d import Point2Voxel as Point2VoxelGPU4d
class nullcontext(AbstractContextManager):
"""Context manager that does no additional processing.
Used as a stand-in for a normal context manager, when a particular
block of code is only sometimes used with a normal context manager:
cm = optional_cm if condition else nullcontext()
with cm:
# Perform operation, using optional_cm if condition is True
"""
def __init__(self, enter_result=None):
self.enter_result = enter_result
def __enter__(self):
return self.enter_result
def __exit__(self, *excinfo):
pass
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
STR = """
BWG 0.0008761882781982422
BWG 0.0008311271667480469
BWG 0.002079486846923828
BWG 0.002329587936401367
BWG 0.0025458335876464844
BWG 0.0026700496673583984
BWG 0.002583742141723633
BWG 0.0025262832641601562
BWG 0.003481149673461914
BWG 0.003238201141357422
BWG 0.005095958709716797
BWG 0.0037899017333984375
BWG 0.003931283950805664
BWG 0.003300189971923828
"""
"""
0.003921985626220703
0.0049707889556884766
0.0052530765533447266
0.0060312747955322266
0.0036766529083251953
0.00421142578125
0.002129793167114258
0.0023038387298583984
0.0013151168823242188
0.0015285015106201172
0.0008392333984375
0.0008127689361572266
0.0002486705780029297
0.00030994415283203125
"""
STR1 = """
SUBM 0.0005137920379638672
F 0.0012662410736083984
F 0.0016875267028808594
REGU 0.0009055137634277344
M 0.0009114742279052734
SUBM 0.00037789344787597656
F 0.0020329952239990234
F 0.001947641372680664
REGU 0.0009374618530273438
M 0.00045609474182128906
SUBM 0.0009856224060058594
F 0.0009992122650146484
F 0.0010600090026855469
REGU 0.0006346702575683594
M 0.0004057884216308594
SUBM 0.0006394386291503906
F 0.0008478164672851562
F 0.0008838176727294922
REGU 0.0007183551788330078
M 0.00025177001953125
SUBM 0.0009539127349853516
F 0.0009481906890869141
F 0.0010502338409423828
REGU 0.0007147789001464844
M 0.000274658203125
SUBM 0.0007004737854003906
F 0.0009715557098388672
F 0.0012331008911132812
REGU 0.0008800029754638672
M 0.0002167224884033203
SUBM 0.00045108795166015625
F 0.0006735324859619141
F 0.0008375644683837891
"""
STR2 = """
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A0T1688_NS00_C3_01LLL_1 0.0007038116455078125
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0007627010345458984
F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0007650852203369141
F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0008864402770996094
F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0004017353057861328
F Turing_f16f16f16f16f16tnt_m32n128k64m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006165504455566406
F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0005872249603271484
F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006289482116699219
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002968311309814453
F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0003299713134765625
F Turing_f16f16f16f16f16tnt_m64n128k64m32n64k32A1T1688_NS00_C3_01LLL_1 0.0002288818359375
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002830028533935547
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0001780986785888672
F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0003058910369873047
"""
def _handle_lines(s: str):
arr = s.split(" ")
return (arr[0], float(arr[-1]))
from cumm.gemm.codeops import group_by
def print_str(s: str):
nums = list(map(_handle_lines, s.strip().split("\n")))
num_dict = group_by(lambda x: x[0], nums)
num_dict_ = {k: sum([vv[1] for vv in v]) for k, v in num_dict.items()}
print(num_dict_)
print_str(STR1)
print_str(STR2)
\ No newline at end of file
...@@ -23,6 +23,8 @@ from spconv.core import ConvAlgo ...@@ -23,6 +23,8 @@ from spconv.core import ConvAlgo
import spconv.pytorch as spconv import spconv.pytorch as spconv
from spconv.utils import Point2VoxelCPU3d from spconv.utils import Point2VoxelCPU3d
def waymo_data(batch_size=1): def waymo_data(batch_size=1):
gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3,
150000, 1) 150000, 1)
...@@ -68,7 +70,6 @@ class Net(nn.Module): ...@@ -68,7 +70,6 @@ class Net(nn.Module):
# nn.BatchNorm1d(32), # nn.BatchNorm1d(32),
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"), # spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(64, spconv.SubMConv3d(64,
96, 96,
...@@ -101,7 +102,6 @@ class Net(nn.Module): ...@@ -101,7 +102,6 @@ class Net(nn.Module):
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"), # spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(128, spconv.SubMConv3d(128,
160, 160,
...@@ -118,7 +118,6 @@ class Net(nn.Module): ...@@ -118,7 +118,6 @@ class Net(nn.Module):
# nn.BatchNorm1d(128), # nn.BatchNorm1d(128),
# nn.ReLU(), # nn.ReLU(),
# spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"), # spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
spconv.SubMConv3d(160, spconv.SubMConv3d(160,
192, 192,
...@@ -136,7 +135,6 @@ class Net(nn.Module): ...@@ -136,7 +135,6 @@ class Net(nn.Module):
# nn.ReLU(), # nn.ReLU(),
spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo), spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
# spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"), # spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
spconv.SubMConv3d(192, spconv.SubMConv3d(192,
224, 224,
3, 3,
...@@ -174,7 +172,6 @@ class Net(nn.Module): ...@@ -174,7 +172,6 @@ class Net(nn.Module):
# # nn.ReLU(), # # nn.ReLU(),
# spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo), # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo),
) )
max_batch_size = 1 max_batch_size = 1
# grid (dense map) is used for indice generation. use pre-allocated grid can run faster. # grid (dense map) is used for indice generation. use pre-allocated grid can run faster.
...@@ -183,16 +180,25 @@ class Net(nn.Module): ...@@ -183,16 +180,25 @@ class Net(nn.Module):
# self.grid = None # self.grid = None
self.shape = shape self.shape = shape
def forward(self, features, coors, batch_size): def forward(self, features, coors, batch_size, enable_timer: bool = False):
x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, x = spconv.SparseConvTensor(features,
self.grid) coors,
self.shape,
batch_size,
self.grid,
enable_timer=enable_timer)
return self.net(x) return self.net(x)
class Net2(nn.Module): class Net2(nn.Module):
def __init__(self, shape, algo): def __init__(self, shape, algo):
super().__init__() super().__init__()
self.net = spconv.SparseSequential( self.net = spconv.SparseSequential(
spconv.SubMConv3d(3, 128, 3, bias=False, indice_key="c0", spconv.SubMConv3d(3,
128,
3,
bias=False,
indice_key="c0",
algo=algo), algo=algo),
# spconv.SubMConv3d(32, # spconv.SubMConv3d(32,
# 32, # 32,
...@@ -240,6 +246,7 @@ class Net2(nn.Module): ...@@ -240,6 +246,7 @@ class Net2(nn.Module):
self.grid) self.grid)
return self.net(x) return self.net(x)
import numpy as np import numpy as np
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
...@@ -248,6 +255,7 @@ import torch ...@@ -248,6 +255,7 @@ import torch
from spconv.pytorch.cppcore import torch_tensor_to_tv from spconv.pytorch.cppcore import torch_tensor_to_tv
def sort_bench(): def sort_bench():
with open("/home/yy/asd.pkl", "rb") as f: with open("/home/yy/asd.pkl", "rb") as f:
a_th = pickle.load(f) a_th = pickle.load(f)
...@@ -262,6 +270,7 @@ def sort_bench(): ...@@ -262,6 +270,7 @@ def sort_bench():
a_tv_1 = a_tv.clone() a_tv_1 = a_tv.clone()
SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0]) SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0])
def main(): def main():
import pickle import pickle
np.random.seed(50051) np.random.seed(50051)
...@@ -280,24 +289,55 @@ def main(): ...@@ -280,24 +289,55 @@ def main():
voxels_th = torch.from_numpy(voxels).to(device).to(dtype) voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).to(device).int() coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True voxels_th.requires_grad = True
algo = spconv.ConvAlgo.MaskImplicitGemm algo = spconv.ConvAlgo.Native
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
# Native: 13.7ms
# F32
# MaskSplitImpGemm: 22ms
# MaskImplicitGemm: 23.5ms
# Native: 21.7ms
# Pure Gemm
# Native: 6.6ms
# MaskImpGemm: 4.3ms
# MaskSplitImpGemm: 4.0ms
# F16 Bwd
# MaskSplitImpGemm: 12.2ms
# MaskImpGemm: 13.8ms
# Native: 25.2ms
# F32 Bwd
# Native: 41.9ms
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# algo = None
net = Net(spatial_shape, algo).to(device).eval().to(dtype).train() net = Net(spatial_shape, algo).to(device).eval().to(dtype).train()
spconv.assign_name_for_sparse_modules(net)
print(coors_th.shape) print(coors_th.shape)
out = net(voxels_th, coors_th, 1) out = net(voxels_th, coors_th, 1)
print(out.spatial_shape) print(out.spatial_shape)
print(voxels.mean(), voxels.max(), voxels.min()) print(voxels.mean(), voxels.max(), voxels.min())
dout = np.random.uniform(-0.2, 0.2, dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32)
out.features.shape).astype(np.float32)
dout_t = torch.from_numpy(dout).to(device).to(dtype) dout_t = torch.from_numpy(dout).to(device).to(dtype)
print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min()) print(out.spatial_shape, out.features.mean(), out.features.max(),
out.features.min())
times = [] times = []
with torch.no_grad(): with torch.no_grad():
for i in range(20): for i in range(20):
print("------------") print("------------")
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.time() t = time.time()
out_nograd = net(voxels_th, coors_th, 1) out_nograd = net(voxels_th, coors_th, 1, True)
timer = out_nograd._timer
res = timer.collect_by_name("forward", timer.get_all_pair_time())
res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
print(sum(res.values()) + sum(res2.values()))
# print(timer.get_all_pair_time())
# print(sum(timer.get_all_pair_time().values()))
torch.cuda.synchronize() torch.cuda.synchronize()
# sort_bench() # sort_bench()
times.append(time.time() - t) times.append(time.time() - t)
...@@ -313,8 +353,8 @@ def main(): ...@@ -313,8 +353,8 @@ def main():
# torch.cuda.synchronize() # torch.cuda.synchronize()
# times.append(time.time() - t) # times.append(time.time() - t)
# print((net.grid == -1).float().sum(), net.grid.numel()) # # # print((net.grid == -1).float().sum(), net.grid.numel())
# print("spconv time", time.time() - t) # # # print("spconv time", time.time() - t)
# print("spconv bw time", np.mean(times[5:])) # print("spconv bw time", np.mean(times[5:]))
......
...@@ -30,6 +30,7 @@ from spconv.constants import FILTER_HWIO ...@@ -30,6 +30,7 @@ from spconv.constants import FILTER_HWIO
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False
class SparseConv3dTestTorch(nn.Module): class SparseConv3dTestTorch(nn.Module):
def __init__(self, def __init__(self,
num_layers, num_layers,
...@@ -363,7 +364,10 @@ class TestSpConv(TestCase): ...@@ -363,7 +364,10 @@ class TestSpConv(TestCase):
strides = [1, 2, 3] strides = [1, 2, 3]
paddings = [0, 1, 2] paddings = [0, 1, 2]
dilations = [1, 2, 3] dilations = [1, 2, 3]
algos = [ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.MaskSplitImplicitGemm] algos = [
ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
ConvAlgo.MaskSplitImplicitGemm
]
algos = [ConvAlgo.MaskSplitImplicitGemm] algos = [ConvAlgo.MaskSplitImplicitGemm]
for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
...@@ -375,8 +379,16 @@ class TestSpConv(TestCase): ...@@ -375,8 +379,16 @@ class TestSpConv(TestCase):
device = torch.device(dev) device = torch.device(dev)
num_points = [1000] * bs num_points = [1000] * bs
dtype = torch.float32 dtype = torch.float32
net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net = SparseConv3dTestTorch(1,
d, algo=al).to(device).to(dtype) 3,
shape,
IC,
OC,
k,
s,
p,
d,
algo=al).to(device).to(dtype)
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
d).to(device).to(dtype) d).to(device).to(dtype)
...@@ -390,27 +402,32 @@ class TestSpConv(TestCase): ...@@ -390,27 +402,32 @@ class TestSpConv(TestCase):
indices_t = torch.from_numpy(indices).int().to(device) indices_t = torch.from_numpy(indices).int().to(device)
features_t = torch.from_numpy(features).to(device).to(dtype) features_t = torch.from_numpy(features).to(device).to(dtype)
features_t.requires_grad = True features_t.requires_grad = True
features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype) features_dense_t = torch.from_numpy(features_dense).to(device).to(
dtype)
features_dense_t.requires_grad = True features_dense_t.requires_grad = True
if net.algo == ConvAlgo.Native: if net.algo == ConvAlgo.Native:
if FILTER_HWIO: if FILTER_HWIO:
filters = np.random.uniform(-1, 1, size=[k, k, k, IC, filters = np.random.uniform(-1, 1,
size=[k, k, k, IC,
OC]).astype(np.float32) OC]).astype(np.float32)
else: else:
filters = np.random.uniform(-1, 1, size=[k, k, k, OC, filters = np.random.uniform(-1, 1,
size=[k, k, k, OC,
IC]).astype(np.float32) IC]).astype(np.float32)
filters_t = torch.from_numpy(filters).to(device).to(dtype) filters_t = torch.from_numpy(filters).to(device).to(dtype)
if FILTER_HWIO: if FILTER_HWIO:
net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1, net_ref.net[0].weight.data[:] = filters_t.permute(
2).contiguous() 4, 3, 0, 1, 2).contiguous()
else: else:
net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, net_ref.net[0].weight.data[:] = filters_t.permute(
2).contiguous() 3, 4, 0, 1, 2).contiguous()
else: else:
filters = np.random.uniform(-1, 1, size=[OC, k, k, k, IC]).astype(np.float32) filters = np.random.uniform(-1, 1,
size=[OC, k, k, k,
IC]).astype(np.float32)
filters_t = torch.from_numpy(filters).to(device).to(dtype) filters_t = torch.from_numpy(filters).to(device).to(dtype)
net_ref.net[0].weight.data[:] = filters_t.permute(0, 4, 1, 2, net_ref.net[0].weight.data[:] = filters_t.permute(
3).contiguous() 0, 4, 1, 2, 3).contiguous()
net.net[0].weight.data[:] = filters_t net.net[0].weight.data[:] = filters_t
out_ref = net_ref(features_dense_t) out_ref = net_ref(features_dense_t)
...@@ -446,7 +463,6 @@ class TestSpConv(TestCase): ...@@ -446,7 +463,6 @@ class TestSpConv(TestCase):
self.assertAllClose(dw, dw_ref, atol=1e-4) self.assertAllClose(dw, dw_ref, atol=1e-4)
self.assertAllClose(din_np, din_sparse_np, atol=1e-4) self.assertAllClose(din_np, din_sparse_np, atol=1e-4)
def testSpDeConv3d(self): def testSpDeConv3d(self):
np.random.seed(484) np.random.seed(484)
devices = ["cuda:0"] devices = ["cuda:0"]
...@@ -499,11 +515,11 @@ class TestSpConv(TestCase): ...@@ -499,11 +515,11 @@ class TestSpConv(TestCase):
filters_t = torch.from_numpy(filters).to(device) filters_t = torch.from_numpy(filters).to(device)
print(net_ref.net[0].weight.shape) print(net_ref.net[0].weight.shape)
if FILTER_HWIO: if FILTER_HWIO:
net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, net_ref.net[0].weight.data[:] = filters_t.permute(
2).contiguous() 3, 4, 0, 1, 2).contiguous()
else: else:
net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1, net_ref.net[0].weight.data[:] = filters_t.permute(
2).contiguous() 4, 3, 0, 1, 2).contiguous()
net.net[0].weight.data[:] = filters_t net.net[0].weight.data[:] = filters_t
out_ref = net_ref(features_dense_t) out_ref = net_ref(features_dense_t)
out = net(features_t, indices_t, bs).dense() out = net(features_t, indices_t, bs).dense()
...@@ -532,7 +548,6 @@ class TestSpConv(TestCase): ...@@ -532,7 +548,6 @@ class TestSpConv(TestCase):
dw = dw.transpose(4, 3, 0, 1, 2) dw = dw.transpose(4, 3, 0, 1, 2)
self.assertAllClose(dw, dw_ref, atol=1e-4) self.assertAllClose(dw, dw_ref, atol=1e-4)
def testSpCpConv3d(self): def testSpCpConv3d(self):
np.random.seed(484) np.random.seed(484)
devices = ["cuda:0", "cpu:0"] devices = ["cuda:0", "cpu:0"]
......
2.1.3 2.1.5
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment