Commit 82fd7a8b authored by yan.yan's avatar yan.yan
Browse files

v2.1.5: add profile tool and python 3.6 for linux

parent f31eee3a
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,27 +15,27 @@ ...@@ -15,27 +15,27 @@
import contextlib import contextlib
from cumm.gemm.core.metaarray import MetaArray, seq from cumm.gemm.core.metaarray import MetaArray, seq
from cumm import dtypes from cumm import dtypes
import pccm import pccm
from cumm.gemm.layout import TensorGeneric, to_stride from cumm.gemm.layout import TensorGeneric, to_stride
from cumm.common import TensorView, TensorViewHashKernel from cumm.common import TensorView, TensorViewHashKernel
from cumm.gemm import codeops from cumm.gemm import codeops
from typing import List from typing import List
from cumm.conv.params import ConvProblem from cumm.conv.params import ConvProblem
import numpy as np import numpy as np
class Point2VoxelCommon(pccm.ParameterizedClass): class Point2VoxelCommon(pccm.ParameterizedClass):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
super().__init__() super().__init__()
self.add_dependency(TensorView) self.add_dependency(TensorView)
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
ret_str = f"std::array<int, {self.ndim}>" ret_str = f"std::array<int, {self.ndim}>"
retf_str = f"std::array<float, {self.ndim}>" retf_str = f"std::array<float, {self.ndim}>"
retf2_str = f"std::array<float, {self.ndim * 2}>" retf2_str = f"std::array<float, {self.ndim * 2}>"
self.calc_meta_ret = f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>" self.calc_meta_ret = f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>"
@pccm.pybind.mark
@pccm.static_function @pccm.static_function
def calc_meta_data(self): def calc_meta_data(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
...@@ -80,7 +80,8 @@ class Point2VoxelCommon(pccm.ParameterizedClass): ...@@ -80,7 +80,8 @@ class Point2VoxelCommon(pccm.ParameterizedClass):
retf_str = f"std::array<float, {self.ndim}>" retf_str = f"std::array<float, {self.ndim}>"
retf2_str = f"std::array<float, {self.ndim * 2}>" retf2_str = f"std::array<float, {self.ndim * 2}>"
return code.ret(f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>") return code.ret(
f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>")
@pccm.static_function @pccm.static_function
def array2tvarray(self): def array2tvarray(self):
...@@ -112,16 +113,21 @@ class Point2VoxelCommon(pccm.ParameterizedClass): ...@@ -112,16 +113,21 @@ class Point2VoxelCommon(pccm.ParameterizedClass):
""") """)
return code.ret("std::array<T, N>") return code.ret("std::array<T, N>")
class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
"""this class don't support multi-thread. """this class don't support multi-thread.
create p2v for every thread. create p2v for every thread.
""" """
def __init__(self, dtype: dtypes.DType, ndim: int, layout: TensorGeneric, zyx: bool = True): def __init__(self,
dtype: dtypes.DType,
ndim: int,
layout: TensorGeneric,
zyx: bool = True):
super().__init__() super().__init__()
self.add_dependency(TensorView, TensorViewHashKernel) self.add_dependency(TensorView, TensorViewHashKernel)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
...@@ -142,7 +148,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -142,7 +148,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
point_xyz = f"{self.ndim - 1} - j" point_xyz = f"{self.ndim - 1} - j"
if not self.zyx: if not self.zyx:
point_xyz = f"j" point_xyz = f"j"
# if zyx, the coors_range and grid_bound is zyx too, # if zyx, the coors_range and grid_bound is zyx too,
# generated indices is zyx. # generated indices is zyx.
code.raw(f""" code.raw(f"""
for (int i : tv::KernelLoopX<int>(num_points)){{ for (int i : tv::KernelLoopX<int>(num_points)){{
...@@ -166,7 +172,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -166,7 +172,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
}} }}
}} }}
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def assign_table(self): def assign_table(self):
...@@ -190,7 +196,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -190,7 +196,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
}} }}
}} }}
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def generate_voxel(self): def generate_voxel(self):
...@@ -231,7 +237,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -231,7 +237,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
}} }}
}} }}
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def voxel_empty_fill_mean(self): def voxel_empty_fill_mean(self):
...@@ -263,7 +269,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -263,7 +269,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
}} }}
}} }}
""") """)
return code return code
@pccm.cuda.cuda_global_function @pccm.cuda.cuda_global_function
def limit_num_per_voxel_value(self): def limit_num_per_voxel_value(self):
...@@ -276,7 +282,8 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -276,7 +282,8 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
num_per_voxel[i] = count; num_per_voxel[i] = count;
}} }}
""") """)
return code return code
class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True):
...@@ -286,14 +293,23 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -286,14 +293,23 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
layout = TensorGeneric(ndim, True) layout = TensorGeneric(ndim, True)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
cuda_funcs = [self.point_to_voxel_hash, self.point_to_voxel_hash_static] cuda_funcs = [
self.add_impl_only_param_class(cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx)) self.point_to_voxel_hash, self.point_to_voxel_hash_static
]
self.add_pybind_member("hashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_impl_only_param_class(
self.add_pybind_member("point_indice_data", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx))
self.add_pybind_member("hashdata",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("point_indice_data",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("voxels", "tv::Tensor", readwrite=False)
self.add_pybind_member("indices", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False)
...@@ -357,7 +373,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -357,7 +373,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
hashdata = tv::zeros({{1}}, tv::custom128, 0); hashdata = tv::zeros({{1}}, tv::custom128, 0);
point_indice_data = tv::zeros({{1}}, tv::int64, 0); point_indice_data = tv::zeros({{1}}, tv::int64, 0);
""") """)
return code return code
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.member_function @pccm.cuda.member_function
...@@ -439,13 +455,13 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -439,13 +455,13 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>") return code.ret("std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>")
@pccm.pybind.mark @pccm.pybind.mark
@pccm.cuda.static_function @pccm.cuda.static_function
def point_to_voxel_hash_static(self): def point_to_voxel_hash_static(self):
code = pccm.FunctionCode() code = pccm.FunctionCode()
code.arg("points", "tv::Tensor") code.arg("points", "tv::Tensor")
code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor") code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data",
"tv::Tensor")
code.arg("vsize", f"std::array<float, {self.ndim}>") code.arg("vsize", f"std::array<float, {self.ndim}>")
code.arg("grid_size, grid_stride", f"std::array<int, {self.ndim}>") code.arg("grid_size, grid_stride", f"std::array<int, {self.ndim}>")
code.arg("coors_range", f"std::array<float, {self.ndim * 2}>") code.arg("coors_range", f"std::array<float, {self.ndim * 2}>")
...@@ -527,13 +543,16 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -527,13 +543,16 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
self.add_dependency(TensorView) self.add_dependency(TensorView)
layout = TensorGeneric(ndim, True) layout = TensorGeneric(ndim, True)
self.add_param_class("layout_ns", layout, "Layout") self.add_param_class("layout_ns", layout, "Layout")
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.zyx = zyx self.zyx = zyx
self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx) self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx)
self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon")
self.add_pybind_member("densehashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") self.add_pybind_member("densehashdata",
"tv::Tensor",
readwrite=False,
pyanno="cumm.tensorview.Tensor")
self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("voxels", "tv::Tensor", readwrite=False)
self.add_pybind_member("indices", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False)
...@@ -568,7 +587,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -568,7 +587,6 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
""") """)
return code.ret(self.p2v_c.calc_meta_ret) return code.ret(self.p2v_c.calc_meta_ret)
@pccm.pybind.mark @pccm.pybind.mark
@pccm.constructor @pccm.constructor
def ctor(self): def ctor(self):
...@@ -613,7 +631,7 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): ...@@ -613,7 +631,7 @@ class Point2VoxelCPU(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
densehashdata_ptr[i] = -1; densehashdata_ptr[i] = -1;
}} }}
""") """)
return code return code
def point_to_voxel_static_template(self, mean: bool = False): def point_to_voxel_static_template(self, mean: bool = False):
code = pccm.FunctionCode() code = pccm.FunctionCode()
......
...@@ -4,13 +4,14 @@ from pathlib import Path ...@@ -4,13 +4,14 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from spconv.pytorch import ops from spconv.pytorch import ops, functional
from spconv.pytorch.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, from spconv.pytorch.conv import (SparseConv2d, SparseConv3d,
SparseConvTranspose3d, SparseInverseConv2d, SparseConvTranspose2d, SparseConvTranspose3d,
SparseInverseConv3d, SubMConv2d, SubMConv3d) SparseInverseConv2d, SparseInverseConv3d,
SubMConv2d, SubMConv3d)
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.pytorch.identity import Identity from spconv.pytorch.identity import Identity
from spconv.pytorch.modules import SparseModule, SparseSequential from spconv.pytorch.modules import SparseModule, SparseSequential, assign_name_for_sparse_modules
from spconv.pytorch.ops import ConvAlgo from spconv.pytorch.ops import ConvAlgo
from spconv.pytorch.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pytorch.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable
......
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import torch
try: try:
remove_plus = torch.__version__.find("+") remove_plus = torch.__version__.find("+")
remove_dotdev = torch.__version__.find(".dev") remove_dotdev = torch.__version__.find(".dev")
...@@ -26,4 +26,4 @@ try: ...@@ -26,4 +26,4 @@ try:
PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split("."))) PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split(".")))
except: except:
# for unknown errors, just set a version # for unknown errors, just set a version
PYTORCH_VERSION = [1, 8, 0] PYTORCH_VERSION = [1, 8, 0]
\ No newline at end of file
...@@ -24,12 +24,13 @@ from torch.nn.parameter import Parameter ...@@ -24,12 +24,13 @@ from torch.nn.parameter import Parameter
from spconv import pytorch as spconv from spconv import pytorch as spconv
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
import spconv.pytorch.functional as Fsp from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO from spconv.constants import FILTER_HWIO
from spconv.utils import nullcontext
def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo): def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
...@@ -205,6 +206,7 @@ class SparseConvolution(SparseModule): ...@@ -205,6 +206,7 @@ class SparseConvolution(SparseModule):
self.dilation) self.dilation)
else: else:
out_spatial_shape = spatial_shape out_spatial_shape = spatial_shape
# print(self._sparse_unique_name, spatial_shape, out_spatial_shape)
# input.update_grid(out_spatial_shape) # input.update_grid(out_spatial_shape)
# t = time.time() # t = time.time()
out_tensor = input.shadow_copy() out_tensor = input.shadow_copy()
...@@ -247,158 +249,165 @@ class SparseConvolution(SparseModule): ...@@ -247,158 +249,165 @@ class SparseConvolution(SparseModule):
out_tensor = out_tensor.replace_feature(features) out_tensor = out_tensor.replace_feature(features)
return out_tensor return out_tensor
indice_dict = input.indice_dict.copy() indice_dict = input.indice_dict.copy()
algo = self.algo algo = self.algo
if self.indice_key is not None : if self.indice_key is not None:
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
if datas is not None: if datas is not None:
msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key." msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
assert algo == datas.algo, msg assert algo == datas.algo, msg
# algo = datas.algo # algo = datas.algo
if algo == ConvAlgo.Native: profile_ctx = nullcontext()
datas = input.find_indice_pair(self.indice_key) if input._timer is not None and self._sparse_unique_name:
if datas is not None: profile_ctx = input._timer.namespace(self._sparse_unique_name)
assert isinstance(datas, IndiceData) with profile_ctx:
if self.inverse: if algo == ConvAlgo.Native:
assert datas is not None and self.indice_key is not None datas = input.find_indice_pair(self.indice_key)
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." if datas is not None:
assert isinstance(datas, IndiceData)
outids = datas.indices if self.inverse:
indice_pairs = datas.indice_pairs assert datas is not None and self.indice_key is not None
indice_pair_num = datas.indice_pair_num assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
out_spatial_shape = datas.out_spatial_shape
assert indice_pair_num.shape[0] == np.prod( outids = datas.indices
self.kernel_size
), "inverse conv must have same kernel size as its couple conv"
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
indice_pairs = datas.indice_pairs indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num indice_pair_num = datas.indice_pair_num
out_spatial_shape = datas.out_spatial_shape
assert indice_pair_num.shape[0] == np.prod(
self.kernel_size
), "inverse conv must have same kernel size as its couple conv"
else: else:
if input.benchmark: if self.indice_key is not None and datas is not None:
torch.cuda.synchronize() outids = datas.out_indices
t = time.time() indice_pairs = datas.indice_pairs
outids, indice_pairs, indice_pair_num = ops.get_indice_pairs( indice_pair_num = datas.indice_pair_num
indices, batch_size, spatial_shape, algo, else:
self.kernel_size, self.stride, self.padding, if input.benchmark:
self.dilation, self.output_padding, self.subm, torch.cuda.synchronize()
self.transposed) t = time.time()
if input.benchmark: outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
torch.cuda.synchronize() indices, batch_size, spatial_shape, algo,
interval = time.time() - t self.kernel_size, self.stride, self.padding,
out_tensor.benchmark_record[ self.dilation, self.output_padding, self.subm,
self.name]["indice_gen_time"].append(interval) self.transposed)
if input.benchmark:
indice_data = IndiceData(outids, torch.cuda.synchronize()
indices, interval = time.time() - t
indice_pairs, out_tensor.benchmark_record[
indice_pair_num, self.name]["indice_gen_time"].append(interval)
spatial_shape,
is_subm=self.subm, indice_data = IndiceData(outids,
algo=algo) indices,
if self.indice_key is not None: indice_pairs,
msg = f"your indice key {self.indice_key} already exists in this sparse tensor." indice_pair_num,
assert self.indice_key not in indice_dict, msg spatial_shape,
indice_dict[self.indice_key] = indice_data is_subm=self.subm,
if input.benchmark: algo=algo)
torch.cuda.synchronize() if self.indice_key is not None:
t = time.time() msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
indice_pairs_calc = indice_pairs assert self.indice_key not in indice_dict, msg
if indice_pairs.device != features.device: indice_dict[self.indice_key] = indice_data
indice_pairs_calc = indice_pairs.to(features.device) if input.benchmark:
if self.subm: torch.cuda.synchronize()
out_features = Fsp.indice_subm_conv(features, self.weight, t = time.time()
indice_pairs_calc, indice_pairs_calc = indice_pairs
indice_pair_num, if indice_pairs.device != features.device:
outids.shape[0], algo) indice_pairs_calc = indice_pairs.to(features.device)
else: if self.subm:
if self.inverse: out_features = Fsp.indice_subm_conv(
out_features = Fsp.indice_inverse_conv(
features, self.weight, indice_pairs_calc, features, self.weight, indice_pairs_calc,
indice_pair_num, outids.shape[0], algo) indice_pair_num, outids.shape[0], algo, input._timer)
else: else:
out_features = Fsp.indice_conv(features, self.weight, if self.inverse:
indice_pairs_calc, out_features = Fsp.indice_inverse_conv(
indice_pair_num, features, self.weight, indice_pairs_calc,
outids.shape[0], algo) indice_pair_num, outids.shape[0], algo)
else:
else: out_features = Fsp.indice_conv(features, self.weight,
datas = input.find_indice_pair(self.indice_key) indice_pairs_calc,
if datas is not None: indice_pair_num,
assert isinstance(datas, ImplicitGemmIndiceData) outids.shape[0], algo,
if self.inverse: input._timer)
assert datas is not None and self.indice_key is not None
assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
outids = datas.indices
pair_fwd = datas.pair_bwd
pair_bwd = datas.pair_fwd
pair_mask_fwd_splits = datas.pair_mask_bwd_splits
pair_mask_bwd_splits = datas.pair_mask_fwd_splits
mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
masks = datas.masks
else: else:
if self.indice_key is not None and datas is not None: datas = input.find_indice_pair(self.indice_key)
outids = datas.out_indices if datas is not None:
pair_fwd = datas.pair_fwd assert isinstance(datas, ImplicitGemmIndiceData)
pair_bwd = datas.pair_bwd if self.inverse:
pair_mask_fwd_splits = datas.pair_mask_fwd_splits assert datas is not None and self.indice_key is not None
pair_mask_bwd_splits = datas.pair_mask_bwd_splits assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits outids = datas.indices
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits pair_fwd = datas.pair_bwd
pair_bwd = datas.pair_fwd
pair_mask_fwd_splits = datas.pair_mask_bwd_splits
pair_mask_bwd_splits = datas.pair_mask_fwd_splits
mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
masks = datas.masks masks = datas.masks
else: else:
res = ops.get_indice_pairs_implicit_gemm( if self.indice_key is not None and datas is not None:
indices, outids = datas.out_indices
batch_size, pair_fwd = datas.pair_fwd
spatial_shape, pair_bwd = datas.pair_bwd
algo, pair_mask_fwd_splits = datas.pair_mask_fwd_splits
ksize=self.kernel_size, pair_mask_bwd_splits = datas.pair_mask_bwd_splits
stride=self.stride, mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
padding=self.padding, mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
dilation=self.dilation, masks = datas.masks
out_padding=self.output_padding, else:
subm=self.subm, with input._timer.namespace("gen_pairs"):
transpose=self.transposed, res = ops.get_indice_pairs_implicit_gemm(
is_train=self.training, indices,
alloc=input.thrust_allocator) batch_size,
outids = res[0] spatial_shape,
num_inds_per_loc = res[1] algo,
pair_fwd = res[2] ksize=self.kernel_size,
pair_bwd = res[3] stride=self.stride,
pair_mask_fwd_splits = res[4] padding=self.padding,
pair_mask_bwd_splits = res[5] dilation=self.dilation,
mask_argsort_fwd_splits = res[6] out_padding=self.output_padding,
mask_argsort_bwd_splits = res[7] subm=self.subm,
masks = res[8] transpose=self.transposed,
if self.indice_key is not None: is_train=self.training,
indice_data = ImplicitGemmIndiceData( alloc=input.thrust_allocator,
outids, timer=input._timer)
indices, outids = res[0]
pair_fwd, num_inds_per_loc = res[1]
pair_bwd, pair_fwd = res[2]
pair_mask_fwd_splits=pair_mask_fwd_splits, pair_bwd = res[3]
pair_mask_bwd_splits=pair_mask_bwd_splits, pair_mask_fwd_splits = res[4]
mask_argsort_fwd_splits=mask_argsort_fwd_splits, pair_mask_bwd_splits = res[5]
mask_argsort_bwd_splits=mask_argsort_bwd_splits, mask_argsort_fwd_splits = res[6]
masks=masks, mask_argsort_bwd_splits = res[7]
is_subm=self.subm, masks = res[8]
out_spatial_shape=out_spatial_shape, if self.indice_key is not None:
algo=algo) indice_data = ImplicitGemmIndiceData(
msg = f"your indice key {self.indice_key} already exists in this sparse tensor." outids,
assert self.indice_key not in indice_dict, msg indices,
indice_dict[self.indice_key] = indice_data pair_fwd,
if input.benchmark: pair_bwd,
torch.cuda.synchronize() pair_mask_fwd_splits=pair_mask_fwd_splits,
t = time.time() pair_mask_bwd_splits=pair_mask_bwd_splits,
num_activate_out = outids.shape[0] mask_argsort_fwd_splits=mask_argsort_fwd_splits,
out_features = Fsp.implicit_gemm( mask_argsort_bwd_splits=mask_argsort_bwd_splits,
features, self.weight, pair_fwd, pair_bwd, masks=masks,
pair_mask_fwd_splits, pair_mask_bwd_splits, is_subm=self.subm,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, out_spatial_shape=out_spatial_shape,
num_activate_out, masks, self.training, self.subm) algo=algo)
msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
assert self.indice_key not in indice_dict, msg
indice_dict[self.indice_key] = indice_data
if input.benchmark:
torch.cuda.synchronize()
t = time.time()
num_activate_out = outids.shape[0]
out_features = Fsp.implicit_gemm(
features, self.weight, pair_fwd, pair_bwd,
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer)
if self.bias is not None: if self.bias is not None:
out_features += self.bias out_features += self.bias
if input.benchmark: if input.benchmark:
......
...@@ -19,6 +19,7 @@ import torch ...@@ -19,6 +19,7 @@ import torch
from spconv.core import ConvAlgo from spconv.core import ConvAlgo
from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.constants import PYTORCH_VERSION
from spconv.pytorch.ops import ThrustSortAllocator from spconv.pytorch.ops import ThrustSortAllocator
from spconv.tools import CUDAKernelTimer
if PYTORCH_VERSION >= [1, 8, 0]: if PYTORCH_VERSION >= [1, 8, 0]:
try: try:
...@@ -51,13 +52,14 @@ class IndiceData(object): ...@@ -51,13 +52,14 @@ class IndiceData(object):
class ImplicitGemmIndiceData(object): class ImplicitGemmIndiceData(object):
def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor, pair_fwd: torch.Tensor, def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor,
pair_bwd: torch.Tensor, pair_fwd: torch.Tensor, pair_bwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
masks: List[np.ndarray], out_spatial_shape, is_subm: bool, algo: ConvAlgo): masks: List[np.ndarray], out_spatial_shape, is_subm: bool,
algo: ConvAlgo):
self.out_indices = out_indices self.out_indices = out_indices
self.indices = indices self.indices = indices
self.pair_fwd = pair_fwd self.pair_fwd = pair_fwd
...@@ -99,7 +101,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -99,7 +101,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
voxel_num: Optional[torch.Tensor] = None, voxel_num: Optional[torch.Tensor] = None,
indice_dict: Optional[dict] = None, indice_dict: Optional[dict] = None,
benchmark: bool = False, benchmark: bool = False,
permanent_thrust_allocator: bool = False): permanent_thrust_allocator: bool = False,
enable_timer: bool = False):
""" """
Args: Args:
features: [num_points, num_features] feature tensor features: [num_points, num_features] feature tensor
...@@ -130,9 +133,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -130,9 +133,10 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.voxel_num = voxel_num # for tensorrt self.voxel_num = voxel_num # for tensorrt
self.benchmark = benchmark self.benchmark = benchmark
self.benchmark_record = {} self.benchmark_record = {}
self.thrust_allocator: Optional[ThrustSortAllocator] = None self.thrust_allocator: Optional[ThrustSortAllocator] = None
if permanent_thrust_allocator: if permanent_thrust_allocator:
self.thrust_allocator = ThrustSortAllocator(features.device) self.thrust_allocator = ThrustSortAllocator(features.device)
self._timer = CUDAKernelTimer(enable_timer)
def replace_feature(self, feature): def replace_feature(self, feature):
"""we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features))
...@@ -144,7 +148,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -144,7 +148,7 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
new_spt.benchmark = self.benchmark new_spt.benchmark = self.benchmark
new_spt.benchmark_record = self.benchmark_record new_spt.benchmark_record = self.benchmark_record
new_spt.thrust_allocator = self.thrust_allocator new_spt.thrust_allocator = self.thrust_allocator
new_spt._timer = self._timer
return new_spt return new_spt
@property @property
...@@ -174,7 +178,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -174,7 +178,8 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
def spatial_size(self): def spatial_size(self):
return np.prod(self.spatial_shape) return np.prod(self.spatial_shape)
def find_indice_pair(self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]: def find_indice_pair(
self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]:
if key is None: if key is None:
return None return None
if key in self.indice_dict: if key in self.indice_dict:
...@@ -208,4 +213,5 @@ class SparseConvTensor(metaclass=SpConvTensorMeta): ...@@ -208,4 +213,5 @@ class SparseConvTensor(metaclass=SpConvTensorMeta):
self.benchmark) self.benchmark)
tensor.benchmark_record = self.benchmark_record tensor.benchmark_record = self.benchmark_record
tensor.thrust_allocator = self.thrust_allocator tensor.thrust_allocator = self.thrust_allocator
tensor._timer = self._timer
return tensor return tensor
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from cumm import tensorview as tv from cumm import tensorview as tv
import torch import torch
from typing import Optional, List from typing import Optional, List
_TORCH_DTYPE_TO_TV = { _TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32, torch.float32: tv.float32,
torch.float64: tv.float64, torch.float64: tv.float64,
...@@ -26,10 +27,13 @@ _TORCH_DTYPE_TO_TV = { ...@@ -26,10 +27,13 @@ _TORCH_DTYPE_TO_TV = {
torch.uint8: tv.uint8, torch.uint8: tv.uint8,
} }
def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Optional[List[int]] = None):
def torch_tensor_to_tv(ten: torch.Tensor,
dtype: Optional[int] = None,
shape: Optional[List[int]] = None):
assert ten.is_contiguous(), "must be contiguous tensor" assert ten.is_contiguous(), "must be contiguous tensor"
ptr = ten.data_ptr() ptr = ten.data_ptr()
device = ten.device device = ten.device
if device.type == "cpu": if device.type == "cpu":
tv_device = -1 tv_device = -1
elif device.type == "cuda": elif device.type == "cuda":
...@@ -42,10 +46,12 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Op ...@@ -42,10 +46,12 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Op
dtype = _TORCH_DTYPE_TO_TV[ten.dtype] dtype = _TORCH_DTYPE_TO_TV[ten.dtype]
return tv.from_blob(ptr, shape, dtype, tv_device) return tv.from_blob(ptr, shape, dtype, tv_device)
def get_current_stream(): def get_current_stream():
return torch.cuda.current_stream().cuda_stream return torch.cuda.current_stream().cuda_stream
if __name__ == "__main__": if __name__ == "__main__":
a = torch.rand(2, 2) a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a) atv = torch_tensor_to_tv(a)
print(atv.numpy_view()) print(atv.numpy_view())
\ No newline at end of file
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
from typing import Optional
import spconv.pytorch.ops as ops from spconv.tools import CUDAKernelTimer
from spconv.pytorch import ops
import torch.cuda.amp as amp import torch.cuda.amp as amp
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
import numpy as np import numpy as np
...@@ -27,23 +28,32 @@ from typing import List ...@@ -27,23 +28,32 @@ from typing import List
class SparseConvFunction(Function): class SparseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
indice_pair_num, indice_pair_num,
num_activate_out, num_activate_out,
False, False,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
...@@ -51,18 +61,27 @@ class SparseConvFunction(Function): ...@@ -51,18 +61,27 @@ class SparseConvFunction(Function):
indice_pairs, indice_pairs,
indice_pair_num, indice_pair_num,
False, False,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseInverseConvFunction(Function): class SparseInverseConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
...@@ -70,13 +89,16 @@ class SparseInverseConvFunction(Function): ...@@ -70,13 +89,16 @@ class SparseInverseConvFunction(Function):
num_activate_out, num_activate_out,
True, True,
False, False,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
grad_output, grad_output,
...@@ -84,29 +106,40 @@ class SparseInverseConvFunction(Function): ...@@ -84,29 +106,40 @@ class SparseInverseConvFunction(Function):
indice_pair_num, indice_pair_num,
True, True,
False, False,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseImplicitGemmFunction(Function): class SparseImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features: torch.Tensor, filters: torch.Tensor, def forward(ctx,
pair_fwd: torch.Tensor, pair_bwd: torch.Tensor, features: torch.Tensor,
filters: torch.Tensor,
pair_fwd: torch.Tensor,
pair_bwd: torch.Tensor,
pair_mask_fwd_splits: List[torch.Tensor], pair_mask_fwd_splits: List[torch.Tensor],
pair_mask_bwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor],
mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor],
mask_argsort_bwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor],
num_activate_out: int, masks: List[np.ndarray], is_train: bool, num_activate_out: int,
is_subm: bool): masks: List[np.ndarray],
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
out, mask_out, mask_width = ops.implicit_gemm( out, mask_out, mask_width = ops.implicit_gemm(features, filters,
features, filters, pair_fwd, pair_mask_fwd_splits, pair_fwd,
mask_argsort_fwd_splits, num_activate_out, masks, is_train, is_subm) pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks,
is_train, is_subm, timer)
ctx.save_for_backward(features, filters, pair_fwd, pair_bwd) ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
ctx.mask_width = mask_width ctx.mask_width = mask_width
ctx.mask_out = mask_out ctx.mask_out = mask_out
ctx.timer = timer
ctx.pair_mask_fwd_splits = pair_mask_fwd_splits ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
ctx.pair_mask_bwd_splits = pair_mask_bwd_splits ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
...@@ -130,30 +163,40 @@ class SparseImplicitGemmFunction(Function): ...@@ -130,30 +163,40 @@ class SparseImplicitGemmFunction(Function):
# num_activate_out = ctx.num_activate_out # num_activate_out = ctx.num_activate_out
masks = ctx.masks masks = ctx.masks
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer
input_bp, filters_bp = ops.implicit_gemm_backward(features, input_bp, filters_bp = ops.implicit_gemm_backward(
filters, features,
grad_output, filters,
pair_fwd, grad_output,
pair_bwd, pair_fwd,
pair_mask_fwd_splits, pair_bwd,
pair_mask_bwd_splits, pair_mask_fwd_splits,
mask_argsort_fwd_splits, pair_mask_bwd_splits,
mask_argsort_bwd_splits, mask_argsort_fwd_splits,
mask_output_fwd=mask_out, mask_argsort_bwd_splits,
masks=masks, mask_output_fwd=mask_out,
mask_width=mask_width, masks=masks,
is_subm=is_subm) mask_width=mask_width,
None_9 = [None] * 10 is_subm=is_subm,
timer=timer)
None_9 = [None] * 11
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
class SubMConvFunction(Function): class SubMConvFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features, filters, indice_pairs, indice_pair_num, def forward(ctx,
num_activate_out, algo): features,
filters,
indice_pairs,
indice_pair_num,
num_activate_out,
algo,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo ctx.algo = algo
ctx.timer = timer
return ops.indice_conv(features, return ops.indice_conv(features,
filters, filters,
indice_pairs, indice_pairs,
...@@ -161,13 +204,16 @@ class SubMConvFunction(Function): ...@@ -161,13 +204,16 @@ class SubMConvFunction(Function):
num_activate_out, num_activate_out,
False, False,
True, True,
algo=algo) algo=algo,
timer=timer)
@staticmethod @staticmethod
@once_differentiable @once_differentiable
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
timer = ctx.timer
input_bp, filters_bp = ops.indice_conv_backward(features, input_bp, filters_bp = ops.indice_conv_backward(features,
filters, filters,
grad_output, grad_output,
...@@ -175,9 +221,10 @@ class SubMConvFunction(Function): ...@@ -175,9 +221,10 @@ class SubMConvFunction(Function):
indice_pair_num, indice_pair_num,
False, False,
True, True,
algo=ctx.algo) algo=ctx.algo,
timer=timer)
return input_bp, filters_bp, None, None, None, None return input_bp, filters_bp, None, None, None, None, None
class SparseMaxPoolFunction(Function): class SparseMaxPoolFunction(Function):
...@@ -199,12 +246,14 @@ class SparseMaxPoolFunction(Function): ...@@ -199,12 +246,14 @@ class SparseMaxPoolFunction(Function):
indice_pairs, indice_pair_num) indice_pairs, indice_pair_num)
return input_bp, None, None, None return input_bp, None, None, None
class SparseMaxPoolImplicitGemmFunction(Function): class SparseMaxPoolImplicitGemmFunction(Function):
@staticmethod @staticmethod
@amp.custom_fwd(cast_inputs=torch.float16) @amp.custom_fwd(cast_inputs=torch.float16)
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, indice_pairs_bwd: torch.Tensor, def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
num_activate_out: int): indice_pairs_bwd: torch.Tensor, num_activate_out: int):
out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd, num_activate_out) out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
num_activate_out)
ctx.save_for_backward(indice_pairs_bwd, features, out) ctx.save_for_backward(indice_pairs_bwd, features, out)
return out return out
...@@ -213,10 +262,11 @@ class SparseMaxPoolImplicitGemmFunction(Function): ...@@ -213,10 +262,11 @@ class SparseMaxPoolImplicitGemmFunction(Function):
@amp.custom_bwd @amp.custom_bwd
def backward(ctx, grad_output): def backward(ctx, grad_output):
indice_pairs_bwd, features, out = ctx.saved_tensors indice_pairs_bwd, features, out = ctx.saved_tensors
input_bp = ops.indice_maxpool_implicit_gemm_backward(features, out, grad_output, input_bp = ops.indice_maxpool_implicit_gemm_backward(
indice_pairs_bwd) features, out, grad_output, indice_pairs_bwd)
return input_bp, None, None, None return input_bp, None, None, None
indice_conv = SparseConvFunction.apply indice_conv = SparseConvFunction.apply
implicit_gemm = SparseImplicitGemmFunction.apply implicit_gemm = SparseImplicitGemmFunction.apply
indice_inverse_conv = SparseInverseConvFunction.apply indice_inverse_conv = SparseInverseConvFunction.apply
......
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import time import time
from collections import OrderedDict from collections import OrderedDict
...@@ -53,6 +52,7 @@ class SparseModule(nn.Module): ...@@ -53,6 +52,7 @@ class SparseModule(nn.Module):
def __init__(self, name=None): def __init__(self, name=None):
super().__init__() super().__init__()
self.name = name self.name = name
self._sparse_unique_name = ""
class SparseSequential(SparseModule): class SparseSequential(SparseModule):
...@@ -143,3 +143,8 @@ class SparseSequential(SparseModule): ...@@ -143,3 +143,8 @@ class SparseSequential(SparseModule):
input = module(input) input = module(input)
return input return input
def assign_name_for_sparse_modules(module: nn.Module):
for k, n in module.named_modules():
if isinstance(n, SparseModule):
n._sparse_unique_name = k
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
# Copyright 2021 Yan Yan # Copyright 2021 Yan Yan
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,18 +15,18 @@ ...@@ -15,18 +15,18 @@
import torch import torch
from torch.autograd import Function from torch.autograd import Function
import spconv.pytorch as spconv
#from torch.nn import Module #from torch.nn import Module
from spconv.pytorch.modules import SparseModule from spconv.pytorch.modules import SparseModule
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from typing import List from typing import List
class JoinTable(SparseModule): # Module): class JoinTable(SparseModule): # Module):
def forward(self, input: List[SparseConvTensor]): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor( output = SparseConvTensor(torch.cat([i.features for i in input], 1),
torch.cat([i.features for i in input], 1), input[0].indices, input[0].indices, input[0].spatial_shape,
input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, input[0].batch_size, input[0].grid,
input[0].indice_dict) input[0].voxel_num, input[0].indice_dict)
output.benchmark_record = input[1].benchmark_record output.benchmark_record = input[1].benchmark_record
output.thrust_allocator = input[1].thrust_allocator output.thrust_allocator = input[1].thrust_allocator
return output return output
...@@ -37,10 +37,10 @@ class JoinTable(SparseModule): # Module): ...@@ -37,10 +37,10 @@ class JoinTable(SparseModule): # Module):
class AddTable(SparseModule): # Module): class AddTable(SparseModule): # Module):
def forward(self, input: List[SparseConvTensor]): def forward(self, input: List[SparseConvTensor]):
output = spconv.SparseConvTensor( output = SparseConvTensor(sum([i.features for i in input]),
sum([i.features for i in input]), input[0].indices, input[0].indices, input[0].spatial_shape,
input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, input[0].batch_size, input[0].grid,
input[0].indice_dict) input[0].voxel_num, input[0].indice_dict)
output.benchmark_record = input[1].benchmark_record output.benchmark_record = input[1].benchmark_record
output.thrust_allocator = input[1].thrust_allocator output.thrust_allocator = input[1].thrust_allocator
return output return output
......
...@@ -82,24 +82,25 @@ class PointToVoxel(object): ...@@ -82,24 +82,25 @@ class PointToVoxel(object):
if self.point_indice_data.shape[0] < pc.shape[0]: if self.point_indice_data.shape[0] < pc.shape[0]:
self.point_indice_data = torch.empty([pc.shape[0]], self.point_indice_data = torch.empty([pc.shape[0]],
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=self.device)
pc_tv = torch_tensor_to_tv(pc) pc_tv = torch_tensor_to_tv(pc)
stream = get_current_stream() stream = get_current_stream()
voxels_tv = torch_tensor_to_tv(self.voxels) voxels_tv = torch_tensor_to_tv(self.voxels)
indices_tv = torch_tensor_to_tv(self.indices) indices_tv = torch_tensor_to_tv(self.indices)
num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel) num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel)
hashdata_tv = torch_tensor_to_tv(self.hashdata, hashdata_tv = torch_tensor_to_tv(
dtype=tv.custom128, self.hashdata,
shape=[self.hashdata.shape[0]]) dtype=tv.custom128,
point_indice_data_tv = torch_tensor_to_tv(self.point_indice_data) shape=[self.hashdata.shape[0]])
point_indice_data_tv = torch_tensor_to_tv(
self.point_indice_data)
res = SpconvOps.point2voxel_cuda(pc_tv, voxels_tv, indices_tv, res = SpconvOps.point2voxel_cuda(
num_per_voxel_tv, hashdata_tv, pc_tv, voxels_tv, indices_tv, num_per_voxel_tv,
point_indice_data_tv, self.vsize, hashdata_tv, point_indice_data_tv, self.vsize,
self.grid_size, self.grid_stride, self.grid_size, self.grid_stride, self.coors_range,
self.coors_range, empty_mean, empty_mean, clear_voxels, stream)
clear_voxels, stream)
num_voxels = res[0].shape[0] num_voxels = res[0].shape[0]
else: else:
pc_tv = torch_tensor_to_tv(pc) pc_tv = torch_tensor_to_tv(pc)
...@@ -111,8 +112,9 @@ class PointToVoxel(object): ...@@ -111,8 +112,9 @@ class PointToVoxel(object):
res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv, res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv,
num_per_voxel_tv, hashdata_tv, num_per_voxel_tv, hashdata_tv,
self.vsize, self.grid_size, self.vsize, self.grid_size,
self.grid_stride, self.coors_range, self.grid_stride,
empty_mean, clear_voxels) self.coors_range, empty_mean,
clear_voxels)
num_voxels = res[0].shape[0] num_voxels = res[0].shape[0]
return (self.voxels[:num_voxels], self.indices[:num_voxels], return (self.voxels[:num_voxels], self.indices[:num_voxels],
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
2.1.3 2.1.5
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment