Commit 52594038 authored by yan.yan's avatar yan.yan
Browse files

v2.1.21: add sm37, avoid fp16 nan

parent b0f52b8a
# Changelog # Changelog
## [2.1.21] - 2021-12-9
### Added
- add sm_37
- add fp16 kernels witl fp32 accumulator (run slower, but can avoid nan if channel size is too large)
## [2.1.20] - 2021-12-6 ## [2.1.20] - 2021-12-6
### Added ### Added
- Add fp16 conv simt kernels for mixed-training in pascal or older GPUS. WARNING: not optimized for TESLA P100 which has 2x throughput in half. - Add fp16 conv simt kernels for mixed-training in pascal or older GPUS. WARNING: not optimized for TESLA P100 which has 2x throughput in half.
......
...@@ -38,9 +38,9 @@ if cuda_ver: ...@@ -38,9 +38,9 @@ if cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102 cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver) RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}>=0.2.6".format(cuda_ver)] deps = ["cumm-cu{}>=0.2.8".format(cuda_ver)]
else: else:
deps = ["cumm>=0.2.6"] deps = ["cumm>=0.2.8"]
...@@ -158,6 +158,7 @@ if disable_jit is not None and disable_jit == "1": ...@@ -158,6 +158,7 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.utils import BoxOps from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable from spconv.csrc.hash.core import HashTable
from cumm.common import CompileInfo
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS) cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS) convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS)
...@@ -171,9 +172,9 @@ if disable_jit is not None and disable_jit == "1": ...@@ -171,9 +172,9 @@ if disable_jit is not None and disable_jit == "1":
std = "c++14" std = "c++14"
else: else:
std = "c++17" std = "c++17"
cus = [cu, convcu, SpconvOps(), BoxOps(), HashTable()] cus = [cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
cus = [SpconvOps(), BoxOps(), HashTable()] cus = [SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
ext_modules: List[Extension] = [ ext_modules: List[Extension] = [
PCCMExtension(cus, PCCMExtension(cus,
"spconv/core_cc", "spconv/core_cc",
......
...@@ -77,12 +77,12 @@ class SimpleGemm: ...@@ -77,12 +77,12 @@ class SimpleGemm:
if tile_key not in tile_shape_to_algos: if tile_key not in tile_shape_to_algos:
tile_shape_to_algos[tile_key] = [] tile_shape_to_algos[tile_key] = []
tile_shape_to_algos[tile_key].append(i) tile_shape_to_algos[tile_key].append(i)
tile_ms_list = list(tile_ms) tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns) tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks) tile_ks_list = list(tile_ks)
tile_ms_list.sort() tile_ms_list.sort()
tile_ns_list.sort() tile_ns_list.sort()
tile_ks_list.sort() tile_ks_list.sort()
self.static_key_to_meta[k] = SimpleGemmAlgoMeta( self.static_key_to_meta[k] = SimpleGemmAlgoMeta(
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos) tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
...@@ -482,12 +482,12 @@ class SimpleConv: ...@@ -482,12 +482,12 @@ class SimpleConv:
if tile_key not in tile_shape_to_algos: if tile_key not in tile_shape_to_algos:
tile_shape_to_algos[tile_key] = [] tile_shape_to_algos[tile_key] = []
tile_shape_to_algos[tile_key].append(i) tile_shape_to_algos[tile_key].append(i)
tile_ms_list = list(tile_ms) tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns) tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks) tile_ks_list = list(tile_ks)
tile_ms_list.sort() tile_ms_list.sort()
tile_ns_list.sort() tile_ns_list.sort()
tile_ks_list.sort() tile_ks_list.sort()
self.static_key_to_meta[k] = SimpleGemmAlgoMeta( self.static_key_to_meta[k] = SimpleGemmAlgoMeta(
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos) tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
...@@ -514,10 +514,23 @@ class SimpleConv: ...@@ -514,10 +514,23 @@ class SimpleConv:
out: tv.Tensor, layout_i: ConvLayout, out: tv.Tensor, layout_i: ConvLayout,
layout_w: ConvLayout, layout_o: ConvLayout, layout_w: ConvLayout, layout_o: ConvLayout,
arch: Tuple[int, int], op_type: ConvOpType, arch: Tuple[int, int], op_type: ConvOpType,
mask_width: int): mask_width: int, fp32_accum: Optional[bool] = None):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16
use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during
# forward.
if is_fp16:
if fp32_accum is None:
if op_type == ConvOpType.kForward:
use_f32_as_accum = weight.dim(-1) * kv > 128 * 27
elif op_type == ConvOpType.kBackwardInput:
use_f32_as_accum = weight.dim(0) * kv > 128 * 27
else:
use_f32_as_accum = fp32_accum
for algo in avail_algos: for algo in avail_algos:
static_key = (layout_i.layout_type.value, static_key = (layout_i.layout_type.value,
layout_w.layout_type.value, layout_w.layout_type.value,
...@@ -531,6 +544,14 @@ class SimpleConv: ...@@ -531,6 +544,14 @@ class SimpleConv:
# skip volta tensor op since it is very slow in architectures except volta. # skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue continue
if arch >= (7, 0) and is_fp16:
# skip simt fp16 kernels if we have tensor core
if desp.algo == GemmAlgo.Simt:
continue
if use_f32_as_accum:
if desp.dacc == tv.float16:
continue
ldi = inp.dim(-1) ldi = inp.dim(-1)
ldw = weight.dim(-1) ldw = weight.dim(-1)
ldo = out.dim(-1) ldo = out.dim(-1)
...@@ -589,9 +610,11 @@ class SimpleConv: ...@@ -589,9 +610,11 @@ class SimpleConv:
mask_output: tv.Tensor = tv.Tensor(), mask_output: tv.Tensor = tv.Tensor(),
alpha: float = 1.0, alpha: float = 1.0,
beta: float = 0.0, beta: float = 0.0,
stream: int = 0): stream: int = 0,
fp32_accum: Optional[bool] = None):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width) layout_o, arch, op_type, mask_width,
fp32_accum)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
output = output.clone() output = output.clone()
......
...@@ -26,6 +26,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -26,6 +26,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from cumm.gemm.main import GemmMainUnitTest from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest from cumm.conv.main import ConvMainUnitTest
from cumm.common import CompileInfo
from spconv.csrc.sparse.all import SpconvOps from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.utils import BoxOps from spconv.csrc.utils import BoxOps
...@@ -41,7 +42,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable( ...@@ -41,7 +42,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
if InWindows: if InWindows:
# windows have command line limit, so we use objects_folder to reduce command size. # windows have command line limit, so we use objects_folder to reduce command size.
objects_folder = "objects" objects_folder = "objects"
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable()], pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()],
PACKAGE_ROOT / "core_cc", PACKAGE_ROOT / "core_cc",
namespace_root=PACKAGE_ROOT, namespace_root=PACKAGE_ROOT,
objects_folder=objects_folder, objects_folder=objects_folder,
......
...@@ -403,7 +403,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -403,7 +403,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -415,7 +415,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -415,7 +415,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -427,7 +427,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -427,7 +427,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -439,7 +439,7 @@ IMPLGEMM_VOLTA_PARAMS = [ ...@@ -439,7 +439,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -490,7 +490,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -490,7 +490,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -502,7 +502,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -502,7 +502,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -514,7 +514,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -514,7 +514,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -526,7 +526,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -526,7 +526,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -538,7 +538,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -538,7 +538,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -550,7 +550,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -550,7 +550,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -562,7 +562,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -562,7 +562,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 64), *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -574,7 +574,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -574,7 +574,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 64), *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 64),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -586,7 +586,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -586,7 +586,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -598,7 +598,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -598,7 +598,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
...@@ -610,7 +610,7 @@ IMPLGEMM_TURING_PARAMS = [ ...@@ -610,7 +610,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE, NDIM_DONT_CARE,
ConvIterAlgo.Optimized, ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"], 2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC, NHWC,
NHWC, NHWC,
NHWC, NHWC,
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo:
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
...@@ -20,5 +20,7 @@ else: ...@@ -20,5 +20,7 @@ else:
CPU_ONLY_BUILD = True CPU_ONLY_BUILD = True
from spconv.core_cc.csrc.utils.boxops import BoxOps from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
HAS_BOOST = BoxOps.has_boost()
HAS_BOOST = BoxOps.has_boost() COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch())
\ No newline at end of file
...@@ -36,8 +36,6 @@ from spconv.utils import nullcontext ...@@ -36,8 +36,6 @@ from spconv.utils import nullcontext
from torch.nn.init import calculate_gain from torch.nn.init import calculate_gain
class SparseConvolution(SparseModule): class SparseConvolution(SparseModule):
__constants__ = [ __constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse', 'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
...@@ -60,6 +58,7 @@ class SparseConvolution(SparseModule): ...@@ -60,6 +58,7 @@ class SparseConvolution(SparseModule):
inverse: bool = False, inverse: bool = False,
indice_key: Optional[str] = None, indice_key: Optional[str] = None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConvolution, self).__init__(name=name) super(SparseConvolution, self).__init__(name=name)
assert groups == 1, "don't support groups for now" assert groups == 1, "don't support groups for now"
...@@ -78,7 +77,9 @@ class SparseConvolution(SparseModule): ...@@ -78,7 +77,9 @@ class SparseConvolution(SparseModule):
if not subm: if not subm:
self.conv1x1 &= kv_stride == 1 self.conv1x1 &= kv_stride == 1
if self.conv1x1: if self.conv1x1:
assert self.padding == [0] * ndim, "padding must be zero for 1x1 conv (k=1,s=1)" assert self.padding == [
0
] * ndim, "padding must be zero for 1x1 conv (k=1,s=1)"
self.transposed = transposed self.transposed = transposed
self.inverse = inverse self.inverse = inverse
self.output_padding = expand_nd(ndim, output_padding) self.output_padding = expand_nd(ndim, output_padding)
...@@ -98,6 +99,7 @@ class SparseConvolution(SparseModule): ...@@ -98,6 +99,7 @@ class SparseConvolution(SparseModule):
if CPU_ONLY_BUILD: if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm" assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo self.algo = algo
self.fp32_accum = fp32_accum
# self.algo = ConvAlgo.Native # self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native: if self.algo == ConvAlgo.Native:
if FILTER_HWIO: if FILTER_HWIO:
...@@ -150,18 +152,25 @@ class SparseConvolution(SparseModule): ...@@ -150,18 +152,25 @@ class SparseConvolution(SparseModule):
mode = mode.lower() mode = mode.lower()
valid_modes = ['fan_in', 'fan_out'] valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes: if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) raise ValueError(
"Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = self._calculate_fan_in_and_fan_out() fan_in, fan_out = self._calculate_fan_in_and_fan_out()
return fan_in if mode == 'fan_in' else fan_out return fan_in if mode == 'fan_in' else fan_out
def _custom_kaiming_uniform_(self, tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): def _custom_kaiming_uniform_(self,
tensor,
a=0,
mode='fan_in',
nonlinearity='leaky_relu'):
r"""same as torch.init.kaiming_uniform_, with KRSC layout support r"""same as torch.init.kaiming_uniform_, with KRSC layout support
""" """
fan = self._calculate_correct_fan(mode) fan = self._calculate_correct_fan(mode)
gain = calculate_gain(nonlinearity, a) gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan) std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation bound = math.sqrt(
3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad(): with torch.no_grad():
return tensor.uniform_(-bound, bound) return tensor.uniform_(-bound, bound)
...@@ -268,7 +277,8 @@ class SparseConvolution(SparseModule): ...@@ -268,7 +277,8 @@ class SparseConvolution(SparseModule):
indice_pairs = datas.indice_pairs indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num indice_pair_num = datas.indice_pair_num
assert self.subm, "only support reuse subm indices" assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, datas) self._check_subm_reuse_valid(input, spatial_shape,
datas)
else: else:
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -287,7 +297,7 @@ class SparseConvolution(SparseModule): ...@@ -287,7 +297,7 @@ class SparseConvolution(SparseModule):
msg += f"transpose={self.transposed}" msg += f"transpose={self.transposed}"
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data(indices) spconv_save_debug_data(indices)
raise e raise e
if input.benchmark: if input.benchmark:
torch.cuda.synchronize() torch.cuda.synchronize()
interval = time.time() - t interval = time.time() - t
...@@ -360,7 +370,8 @@ class SparseConvolution(SparseModule): ...@@ -360,7 +370,8 @@ class SparseConvolution(SparseModule):
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks masks = datas.masks
assert self.subm, "only support reuse subm indices" assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, datas) self._check_subm_reuse_valid(input, spatial_shape,
datas)
else: else:
with input._timer.namespace("gen_pairs"): with input._timer.namespace("gen_pairs"):
...@@ -390,7 +401,7 @@ class SparseConvolution(SparseModule): ...@@ -390,7 +401,7 @@ class SparseConvolution(SparseModule):
msg += f"transpose={self.transposed}" msg += f"transpose={self.transposed}"
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
spconv_save_debug_data(indices) spconv_save_debug_data(indices)
raise e raise e
outids = res[0] outids = res[0]
num_inds_per_loc = res[1] num_inds_per_loc = res[1]
...@@ -432,7 +443,7 @@ class SparseConvolution(SparseModule): ...@@ -432,7 +443,7 @@ class SparseConvolution(SparseModule):
pair_mask_fwd_splits, pair_mask_bwd_splits, pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits, mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm, num_activate_out, masks, self.training, self.subm,
input._timer) input._timer, self.fp32_accum)
if self.bias is not None: if self.bias is not None:
out_features += self.bias out_features += self.bias
if input.benchmark: if input.benchmark:
...@@ -449,21 +460,28 @@ class SparseConvolution(SparseModule): ...@@ -449,21 +460,28 @@ class SparseConvolution(SparseModule):
out_tensor.spatial_shape = out_spatial_shape out_tensor.spatial_shape = out_spatial_shape
return out_tensor return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor,
def _check_subm_reuse_valid(self, inp: SparseConvTensor, spatial_shape: List[int], datas: Union[ImplicitGemmIndiceData, IndiceData]): spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData,
IndiceData]):
assert datas.is_subm, "only support reuse subm indices" assert datas.is_subm, "only support reuse subm indices"
if self.kernel_size != datas.ksize: if self.kernel_size != datas.ksize:
raise ValueError(f"subm with same indice_key must have same kernel" raise ValueError(
f"subm with same indice_key must have same kernel"
f" size, expect {datas.ksize}, this layer {self.kernel_size}") f" size, expect {datas.ksize}, this layer {self.kernel_size}")
if self.dilation != datas.dilation: if self.dilation != datas.dilation:
raise ValueError(f"subm with same indice_key must have same dilation" raise ValueError(
f"subm with same indice_key must have same dilation"
f", expect {datas.dilation}, this layer {self.dilation}") f", expect {datas.dilation}, this layer {self.dilation}")
if inp.spatial_shape != datas.spatial_shape: if inp.spatial_shape != datas.spatial_shape:
raise ValueError(f"subm with same indice_key must have same spatial structure" raise ValueError(
f"subm with same indice_key must have same spatial structure"
f", expect {datas.spatial_shape}, input {spatial_shape}") f", expect {datas.spatial_shape}, input {spatial_shape}")
if inp.indices.shape[0] != datas.indices.shape[0]: if inp.indices.shape[0] != datas.indices.shape[0]:
raise ValueError(f"subm with same indice_key must have same num of indices" raise ValueError(
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}") f"subm with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
)
class SparseConv1d(SparseConvolution): class SparseConv1d(SparseConvolution):
...@@ -478,6 +496,7 @@ class SparseConv1d(SparseConvolution): ...@@ -478,6 +496,7 @@ class SparseConv1d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConv1d, self).__init__(1, super(SparseConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -490,6 +509,7 @@ class SparseConv1d(SparseConvolution): ...@@ -490,6 +509,7 @@ class SparseConv1d(SparseConvolution):
bias, bias,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -505,6 +525,7 @@ class SparseConv2d(SparseConvolution): ...@@ -505,6 +525,7 @@ class SparseConv2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConv2d, self).__init__(2, super(SparseConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -517,6 +538,7 @@ class SparseConv2d(SparseConvolution): ...@@ -517,6 +538,7 @@ class SparseConv2d(SparseConvolution):
bias, bias,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -532,6 +554,7 @@ class SparseConv3d(SparseConvolution): ...@@ -532,6 +554,7 @@ class SparseConv3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConv3d, self).__init__(3, super(SparseConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -544,6 +567,7 @@ class SparseConv3d(SparseConvolution): ...@@ -544,6 +567,7 @@ class SparseConv3d(SparseConvolution):
bias, bias,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -559,6 +583,7 @@ class SparseConv4d(SparseConvolution): ...@@ -559,6 +583,7 @@ class SparseConv4d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConv4d, self).__init__(4, super(SparseConv4d, self).__init__(4,
in_channels, in_channels,
...@@ -571,6 +596,7 @@ class SparseConv4d(SparseConvolution): ...@@ -571,6 +596,7 @@ class SparseConv4d(SparseConvolution):
bias, bias,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -586,6 +612,7 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -586,6 +612,7 @@ class SparseConvTranspose1d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConvTranspose1d, self).__init__(1, super(SparseConvTranspose1d, self).__init__(1,
in_channels, in_channels,
...@@ -599,6 +626,7 @@ class SparseConvTranspose1d(SparseConvolution): ...@@ -599,6 +626,7 @@ class SparseConvTranspose1d(SparseConvolution):
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -614,6 +642,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -614,6 +642,7 @@ class SparseConvTranspose2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConvTranspose2d, self).__init__(2, super(SparseConvTranspose2d, self).__init__(2,
in_channels, in_channels,
...@@ -627,6 +656,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -627,6 +656,7 @@ class SparseConvTranspose2d(SparseConvolution):
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -642,6 +672,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -642,6 +672,7 @@ class SparseConvTranspose3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConvTranspose3d, self).__init__(3, super(SparseConvTranspose3d, self).__init__(3,
in_channels, in_channels,
...@@ -655,6 +686,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -655,6 +686,7 @@ class SparseConvTranspose3d(SparseConvolution):
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -670,6 +702,7 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -670,6 +702,7 @@ class SparseConvTranspose4d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseConvTranspose4d, self).__init__(4, super(SparseConvTranspose4d, self).__init__(4,
in_channels, in_channels,
...@@ -683,6 +716,7 @@ class SparseConvTranspose4d(SparseConvolution): ...@@ -683,6 +716,7 @@ class SparseConvTranspose4d(SparseConvolution):
transposed=True, transposed=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -694,6 +728,7 @@ class SparseInverseConv1d(SparseConvolution): ...@@ -694,6 +728,7 @@ class SparseInverseConv1d(SparseConvolution):
indice_key, indice_key,
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseInverseConv1d, self).__init__(1, super(SparseInverseConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -703,6 +738,7 @@ class SparseInverseConv1d(SparseConvolution): ...@@ -703,6 +738,7 @@ class SparseInverseConv1d(SparseConvolution):
inverse=True, inverse=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -714,6 +750,7 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -714,6 +750,7 @@ class SparseInverseConv2d(SparseConvolution):
indice_key, indice_key,
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseInverseConv2d, self).__init__(2, super(SparseInverseConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -723,6 +760,7 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -723,6 +760,7 @@ class SparseInverseConv2d(SparseConvolution):
inverse=True, inverse=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -734,6 +772,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -734,6 +772,7 @@ class SparseInverseConv3d(SparseConvolution):
indice_key, indice_key,
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseInverseConv3d, self).__init__(3, super(SparseInverseConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -743,6 +782,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -743,6 +782,7 @@ class SparseInverseConv3d(SparseConvolution):
inverse=True, inverse=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -754,6 +794,7 @@ class SparseInverseConv4d(SparseConvolution): ...@@ -754,6 +794,7 @@ class SparseInverseConv4d(SparseConvolution):
indice_key, indice_key,
bias=True, bias=True,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SparseInverseConv4d, self).__init__(4, super(SparseInverseConv4d, self).__init__(4,
in_channels, in_channels,
...@@ -763,6 +804,7 @@ class SparseInverseConv4d(SparseConvolution): ...@@ -763,6 +804,7 @@ class SparseInverseConv4d(SparseConvolution):
inverse=True, inverse=True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -778,6 +820,7 @@ class SubMConv1d(SparseConvolution): ...@@ -778,6 +820,7 @@ class SubMConv1d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SubMConv1d, self).__init__(1, super(SubMConv1d, self).__init__(1,
in_channels, in_channels,
...@@ -791,6 +834,7 @@ class SubMConv1d(SparseConvolution): ...@@ -791,6 +834,7 @@ class SubMConv1d(SparseConvolution):
True, True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -806,6 +850,7 @@ class SubMConv2d(SparseConvolution): ...@@ -806,6 +850,7 @@ class SubMConv2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SubMConv2d, self).__init__(2, super(SubMConv2d, self).__init__(2,
in_channels, in_channels,
...@@ -819,6 +864,7 @@ class SubMConv2d(SparseConvolution): ...@@ -819,6 +864,7 @@ class SubMConv2d(SparseConvolution):
True, True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -834,6 +880,7 @@ class SubMConv3d(SparseConvolution): ...@@ -834,6 +880,7 @@ class SubMConv3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SubMConv3d, self).__init__(3, super(SubMConv3d, self).__init__(3,
in_channels, in_channels,
...@@ -847,6 +894,7 @@ class SubMConv3d(SparseConvolution): ...@@ -847,6 +894,7 @@ class SubMConv3d(SparseConvolution):
True, True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -862,6 +910,7 @@ class SubMConv4d(SparseConvolution): ...@@ -862,6 +910,7 @@ class SubMConv4d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
algo: Optional[ConvAlgo] = None, algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None): name=None):
super(SubMConv4d, self).__init__(4, super(SubMConv4d, self).__init__(4,
in_channels, in_channels,
...@@ -875,4 +924,5 @@ class SubMConv4d(SparseConvolution): ...@@ -875,4 +924,5 @@ class SubMConv4d(SparseConvolution):
True, True,
indice_key=indice_key, indice_key=indice_key,
algo=algo, algo=algo,
fp32_accum=fp32_accum,
name=name) name=name)
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
from cumm import tensorview as tv from cumm import tensorview as tv
import torch import torch
from typing import Optional, List from typing import Optional, List
from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys
_TORCH_DTYPE_TO_TV = { _TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32, torch.float32: tv.float32,
...@@ -53,7 +55,14 @@ def torch_tensors_to_tv(*tens: torch.Tensor): ...@@ -53,7 +55,14 @@ def torch_tensors_to_tv(*tens: torch.Tensor):
def get_current_stream(): def get_current_stream():
return torch.cuda.current_stream().cuda_stream return torch.cuda.current_stream().cuda_stream
def get_arch():
arch = torch.cuda.get_device_capability()
if arch not in COMPILED_CUDA_ARCHS:
print(f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
f"may cause invalid device function. "
f"available: {COMPILED_CUDA_ARCHS}", file=sys.stderr)
return arch
if __name__ == "__main__": if __name__ == "__main__":
a = torch.rand(2, 2) a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a) atv = torch_tensor_to_tv(a)
......
...@@ -179,14 +179,16 @@ class SparseImplicitGemmFunction(Function): ...@@ -179,14 +179,16 @@ class SparseImplicitGemmFunction(Function):
masks: List[np.ndarray], masks: List[np.ndarray],
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
try: try:
out, mask_out, mask_width = ops.implicit_gemm(features, filters, out, mask_out, mask_width = ops.implicit_gemm(features, filters,
pair_fwd, pair_fwd,
pair_mask_fwd_splits, pair_mask_fwd_splits,
mask_argsort_fwd_splits, mask_argsort_fwd_splits,
num_activate_out, masks, num_activate_out, masks,
is_train, is_subm, timer) is_train, is_subm, timer,
fp32_accum)
except Exception as e: except Exception as e:
msg = "[Exception|implicit_gemm]" msg = "[Exception|implicit_gemm]"
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape}," msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
...@@ -208,6 +210,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -208,6 +210,7 @@ class SparseImplicitGemmFunction(Function):
# ctx.num_activate_out = num_activate_out # ctx.num_activate_out = num_activate_out
ctx.masks = masks ctx.masks = masks
ctx.is_subm = is_subm ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum
return out return out
@staticmethod @staticmethod
...@@ -225,6 +228,8 @@ class SparseImplicitGemmFunction(Function): ...@@ -225,6 +228,8 @@ class SparseImplicitGemmFunction(Function):
masks = ctx.masks masks = ctx.masks
is_subm = ctx.is_subm is_subm = ctx.is_subm
timer = ctx.timer timer = ctx.timer
fp32_accum = ctx.fp32_accum
try: try:
input_bp, filters_bp = ops.implicit_gemm_backward( input_bp, filters_bp = ops.implicit_gemm_backward(
features, features,
...@@ -240,7 +245,8 @@ class SparseImplicitGemmFunction(Function): ...@@ -240,7 +245,8 @@ class SparseImplicitGemmFunction(Function):
masks=masks, masks=masks,
mask_width=mask_width, mask_width=mask_width,
is_subm=is_subm, is_subm=is_subm,
timer=timer) timer=timer,
fp32_accum=fp32_accum)
except Exception as e: except Exception as e:
msg = "[Exception|implicit_gemm_backward]" msg = "[Exception|implicit_gemm_backward]"
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape}," msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
...@@ -251,7 +257,7 @@ class SparseImplicitGemmFunction(Function): ...@@ -251,7 +257,7 @@ class SparseImplicitGemmFunction(Function):
masks)) masks))
raise e raise e
None_9 = [None] * 11 None_9 = [None] * 12
return (input_bp, filters_bp, *None_9) return (input_bp, filters_bp, *None_9)
......
...@@ -23,7 +23,7 @@ import spconv ...@@ -23,7 +23,7 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo from spconv.core import AlgoHint, ConvAlgo
from typing import List, Optional, Union from typing import List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream, get_arch
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
import spconv.core_cc as _ext import spconv.core_cc as _ext
...@@ -666,7 +666,7 @@ def indice_conv(features: torch.Tensor, ...@@ -666,7 +666,7 @@ def indice_conv(features: torch.Tensor,
profile_idx = i profile_idx = i
assert nhot_profile > 0, "this shouldn't happen" assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu) # print(nhot_profile, indice_pair_num_cpu)
arch = torch.cuda.get_device_capability() arch = get_arch()
tuned_res = GEMM.get_tuned_algo(a.dtype, tuned_res = GEMM.get_tuned_algo(a.dtype,
filters_tv.dtype, filters_tv.dtype,
...@@ -809,7 +809,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -809,7 +809,7 @@ def indice_conv_backward(features: torch.Tensor,
return (din, dfilters.reshape(filters_shape)) return (din, dfilters.reshape(filters_shape))
maxnhot = max(indice_pair_num_cpu) maxnhot = max(indice_pair_num_cpu)
arch = torch.cuda.get_device_capability() arch = get_arch()
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
dfilters_tv = torch_tensor_to_tv(dfilters) dfilters_tv = torch_tensor_to_tv(dfilters)
...@@ -1051,7 +1051,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1051,7 +1051,8 @@ def implicit_gemm(features: torch.Tensor,
masks: List[np.ndarray], masks: List[np.ndarray],
is_train: bool, is_train: bool,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
stream = get_current_stream() stream = get_current_stream()
# if DEBUG: # if DEBUG:
...@@ -1085,7 +1086,7 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1085,7 +1086,7 @@ def implicit_gemm(features: torch.Tensor,
features_tv = torch_tensor_to_tv(features) features_tv = torch_tensor_to_tv(features)
filters_tv = torch_tensor_to_tv(filters) filters_tv = torch_tensor_to_tv(filters)
out_features_tv = torch_tensor_to_tv(out_features) out_features_tv = torch_tensor_to_tv(out_features)
arch = torch.cuda.get_device_capability() arch = get_arch()
pair_mask_fwd_split_tvs = [ pair_mask_fwd_split_tvs = [
torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
] ]
...@@ -1113,7 +1114,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1113,7 +1114,8 @@ def implicit_gemm(features: torch.Tensor,
indices=pair_fwd_tv, indices=pair_fwd_tv,
reverse_mask=False, reverse_mask=False,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream) stream=stream,
fp32_accum=fp32_accum)
mask_width = tune_res.algo_desp.tile_shape[0] mask_width = tune_res.algo_desp.tile_shape[0]
if is_train: if is_train:
mask_output_fwd = torch.empty( mask_output_fwd = torch.empty(
...@@ -1180,7 +1182,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1180,7 +1182,8 @@ def implicit_gemm_backward(features: torch.Tensor,
masks: List[np.ndarray], masks: List[np.ndarray],
mask_width: int, mask_width: int,
is_subm: bool, is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)): timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
# print(out_bp.mean(), out_bp.max(), out_bp.min()) # print(out_bp.mean(), out_bp.max(), out_bp.min())
if features.dtype == torch.int8 or features.dtype == torch.qint8: if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress") raise NotImplementedError("work in progress")
...@@ -1217,7 +1220,7 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1217,7 +1220,7 @@ def implicit_gemm_backward(features: torch.Tensor,
dout_tv = torch_tensor_to_tv(out_bp) dout_tv = torch_tensor_to_tv(out_bp)
din_tv = torch_tensor_to_tv(din) din_tv = torch_tensor_to_tv(din)
mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd, dtype=tv.uint32) mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd, dtype=tv.uint32)
arch = torch.cuda.get_device_capability() arch = get_arch()
pair_mask_fwd_split_tvs = [ pair_mask_fwd_split_tvs = [
torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
] ]
...@@ -1263,7 +1266,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1263,7 +1266,8 @@ def implicit_gemm_backward(features: torch.Tensor,
indices=pair_bwd_tv, indices=pair_bwd_tv,
reverse_mask=is_subm, reverse_mask=is_subm,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream) stream=stream,
fp32_accum=fp32_accum)
if wgrad_tune_res is None: if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache( wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
......
...@@ -289,7 +289,7 @@ def main(): ...@@ -289,7 +289,7 @@ def main():
voxels_th = torch.from_numpy(voxels).to(device).to(dtype) voxels_th = torch.from_numpy(voxels).to(device).to(dtype)
coors_th = torch.from_numpy(coors).to(device).int() coors_th = torch.from_numpy(coors).to(device).int()
voxels_th.requires_grad = True voxels_th.requires_grad = True
algo = spconv.ConvAlgo.Native algo = spconv.ConvAlgo.MaskImplicitGemm
# 3080 Laptop # 3080 Laptop
# MaskImpGemm: 11.2ms # MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms # MaskSplitImpGemm: 12.2ms
......
2.1.20 2.1.21
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment