Commit 863080a0 authored by yan.yan's avatar yan.yan
Browse files

Merge branch 'master' into develop

parents 06a01f0f 52594038
# Changelog
## [2.1.21] - 2021-12-9
### Added
- add sm_37
- add fp16 kernels with fp32 accumulator (runs slower, but can avoid nan if channel size is too large)
## [2.1.20] - 2021-12-6
### Added
- Add fp16 conv SIMT kernels for mixed-precision training on Pascal or older GPUs. WARNING: not optimized for Tesla P100, which has 2x throughput in half precision.
## [2.1.19] - 2021-12-3
### Fixed
- Fix wrong arch assert in all kernels for old GPUs to make spconv work in sm_50 GPUs
......
......@@ -18,7 +18,7 @@
## Short Guide
* If you train without Tensor Core (i.e. FP32 training), set all ```algo``` in convolution/maxpool to ```ConvAlgo.Native``` manually. Default Algorithm is ```ConvAlgo.MaskImplicitGemm```, which is **SLOWER** than ```ConvAlgo.Native``` when use float32. this will be fixed in spconv 2.2.
* If you train without Tensor Core (i.e. FP32 training or FP16 training for Pascal or older GPUs), set all ```algo``` in convolution/maxpool to ```ConvAlgo.Native``` manually. The default algorithm is ```ConvAlgo.MaskImplicitGemm```, which is **SLOWER** than ```ConvAlgo.Native``` when using float32. This will be fixed in spconv 2.2.
* If your GPU supports Tensor Core, use FP16 (mixed precision training) if possible.
* If you train with mixed precision training (use Tensor Core), you don't need to set algorithm manually.
* Currently the fast algorithm only supports kernel volume (product of kernel size) <= 32, so don't use a large kernel size.
......
......@@ -38,9 +38,9 @@ if cuda_ver:
cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}>=0.2.6".format(cuda_ver)]
deps = ["cumm-cu{}>=0.2.8".format(cuda_ver)]
else:
deps = ["cumm>=0.2.6"]
deps = ["cumm>=0.2.8"]
......@@ -158,6 +158,7 @@ if disable_jit is not None and disable_jit == "1":
from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.utils import BoxOps
from spconv.csrc.hash.core import HashTable
from cumm.common import CompileInfo
cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS)
......@@ -171,9 +172,9 @@ if disable_jit is not None and disable_jit == "1":
std = "c++14"
else:
std = "c++17"
cus = [cu, convcu, SpconvOps(), BoxOps(), HashTable()]
cus = [cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
if CUMM_CPU_ONLY_BUILD:
cus = [SpconvOps(), BoxOps(), HashTable()]
cus = [SpconvOps(), BoxOps(), HashTable(), CompileInfo()]
ext_modules: List[Extension] = [
PCCMExtension(cus,
"spconv/core_cc",
......
......@@ -77,12 +77,12 @@ class SimpleGemm:
if tile_key not in tile_shape_to_algos:
tile_shape_to_algos[tile_key] = []
tile_shape_to_algos[tile_key].append(i)
tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks)
tile_ms_list.sort()
tile_ns_list.sort()
tile_ks_list.sort()
tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks)
tile_ms_list.sort()
tile_ns_list.sort()
tile_ks_list.sort()
self.static_key_to_meta[k] = SimpleGemmAlgoMeta(
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
......@@ -483,12 +483,12 @@ class SimpleConv:
if tile_key not in tile_shape_to_algos:
tile_shape_to_algos[tile_key] = []
tile_shape_to_algos[tile_key].append(i)
tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks)
tile_ms_list.sort()
tile_ns_list.sort()
tile_ks_list.sort()
tile_ms_list = list(tile_ms)
tile_ns_list = list(tile_ns)
tile_ks_list = list(tile_ks)
tile_ms_list.sort()
tile_ns_list.sort()
tile_ks_list.sort()
self.static_key_to_meta[k] = SimpleGemmAlgoMeta(
tile_ms_list, tile_ns_list, tile_ks_list, tile_shape_to_algos)
......@@ -515,10 +515,23 @@ class SimpleConv:
out: tv.Tensor, layout_i: ConvLayout,
layout_w: ConvLayout, layout_o: ConvLayout,
arch: Tuple[int, int], op_type: ConvOpType,
mask_width: int):
mask_width: int, fp32_accum: Optional[bool] = None):
avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = []
is_fp16 = inp.dtype == tv.float16 and weight.dtype == tv.float16 and out.dtype == tv.float16
use_f32_as_accum = False
kv = int(np.prod(weight.shape[1:-1]))
# for 3d conv, if reduce axis is too large, may cause nan during
# forward.
if is_fp16:
if fp32_accum is None:
if op_type == ConvOpType.kForward:
use_f32_as_accum = weight.dim(-1) * kv > 128 * 27
elif op_type == ConvOpType.kBackwardInput:
use_f32_as_accum = weight.dim(0) * kv > 128 * 27
else:
use_f32_as_accum = fp32_accum
for algo in avail_algos:
static_key = (layout_i.layout_type.value,
layout_w.layout_type.value,
......@@ -532,6 +545,14 @@ class SimpleConv:
# skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue
if arch >= (7, 0) and is_fp16:
# skip simt fp16 kernels if we have tensor core
if desp.algo == GemmAlgo.Simt:
continue
if use_f32_as_accum:
if desp.dacc == tv.float16:
continue
ldi = inp.dim(-1)
ldw = weight.dim(-1)
ldo = out.dim(-1)
......@@ -590,9 +611,11 @@ class SimpleConv:
mask_output: tv.Tensor = tv.Tensor(),
alpha: float = 1.0,
beta: float = 0.0,
stream: int = 0):
stream: int = 0,
fp32_accum: Optional[bool] = None):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width)
layout_o, arch, op_type, mask_width,
fp32_accum)
inp = inp.clone()
weight = weight.clone()
output = output.clone()
......
......@@ -26,6 +26,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from cumm.gemm.main import GemmMainUnitTest
from cumm.conv.main import ConvMainUnitTest
from cumm.common import CompileInfo
from spconv.csrc.sparse.all import SpconvOps
from spconv.csrc.utils import BoxOps
......@@ -41,7 +42,7 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
if InWindows:
# windows have command line limit, so we use objects_folder to reduce command size.
objects_folder = "objects"
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable()],
pccm.builder.build_pybind([cu, convcu, SpconvOps(), BoxOps(), HashTable(), CompileInfo()],
PACKAGE_ROOT / "core_cc",
namespace_root=PACKAGE_ROOT,
objects_folder=objects_folder,
......
......@@ -88,18 +88,18 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
# fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f16,f16"],
*gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f32,f32"],
"f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 64, 32), (32, 32, 8), ["f16,f16,f16,f16,f16"],
*gen_shuffle_params((32, 64, 32), (32, 32, 8), ["f16,f16,f16,f32,f32"],
"f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f16,f16,f16,f16,f16"],
*gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f16,f16,f16,f32,f32"],
"f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
# *gen_shuffle_params(
# (64, 64, 16),
# (32, 32, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((64, 128, 16), (32, 64, 8), ["f16,f16,f16,f16,f16"],
*gen_shuffle_params((64, 128, 16), (32, 64, 8), ["f16,f16,f16,f32,f32"],
"f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
*gen_shuffle_params((64, 64, 8), (32, 32, 8), ["f16,f16,f16,f16,f16"],
*gen_shuffle_params((64, 64, 8), (32, 32, 8), ["f16,f16,f16,f32,f32"],
"f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None),
]
......@@ -192,11 +192,13 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
]
# SHUFFLE_TURING_PARAMS = []
# Here we must use f32 accumulators for SIMT f16 kernels because
# half intrinsics are VERY SLOW on the GTX 1000 series.
IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -208,7 +210,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 256, 8), (32, 64, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -220,7 +222,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -232,7 +234,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 32, 32), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -244,7 +246,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 256, 8), (32, 64, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -256,7 +258,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 8), (32, 64, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -268,7 +270,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 8), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -280,7 +282,7 @@ IMPLGEMM_SIMT_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f16,f16"],
2, ["f32,f32,f32,f32,f32", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -404,7 +406,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -416,7 +418,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -428,7 +430,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -440,7 +442,7 @@ IMPLGEMM_VOLTA_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -491,7 +493,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -503,7 +505,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -515,7 +517,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -527,7 +529,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -539,7 +541,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -551,7 +553,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -563,7 +565,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -575,7 +577,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 64),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -587,7 +589,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -599,7 +601,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......@@ -611,7 +613,7 @@ IMPLGEMM_TURING_PARAMS = [
*gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32),
NDIM_DONT_CARE,
ConvIterAlgo.Optimized,
2, ["f16,f16,f16,f16,f16"],
2, ["f16,f16,f16,f16,f16", "f16,f16,f16,f32,f32"],
NHWC,
NHWC,
NHWC,
......
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
# Type stub for the CompileInfo binding; it declares a single static query and
# carries no implementation.
class CompileInfo:
# Returns a list of (int, int) tuples.
# NOTE(review): each tuple looks like a CUDA compute capability
# (major, minor) -- elsewhere in this change the result is turned into a set
# and compared against torch.cuda.get_device_capability(); confirm.
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
......@@ -24,5 +24,7 @@ from spconv.core_cc.csrc.sparse.all import SpconvOps
BUILD_CUMM_VERSION = SpconvOps.cumm_version()
BUILD_PCCM_VERSION = SpconvOps.pccm_version()
from spconv.core_cc.csrc.utils.boxops import BoxOps
from spconv.core_cc.cumm.common import CompileInfo
HAS_BOOST = BoxOps.has_boost()
COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch())
......@@ -52,8 +52,6 @@ def expand_nd(val: Union[int, List[int], Tuple[int, ...]], ndim: int) -> List[in
return val
class SparseConvolution(SparseModule):
__constants__ = [
'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
......@@ -76,6 +74,7 @@ class SparseConvolution(SparseModule):
inverse: bool = False,
indice_key: Optional[str] = None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConvolution, self).__init__(name=name)
assert groups == 1, "don't support groups for now"
......@@ -94,7 +93,9 @@ class SparseConvolution(SparseModule):
if not subm:
self.conv1x1 &= kv_stride == 1
if self.conv1x1:
assert self.padding == [0] * ndim, "padding must be zero for 1x1 conv (k=1,s=1)"
assert self.padding == [
0
] * ndim, "padding must be zero for 1x1 conv (k=1,s=1)"
self.transposed = transposed
self.inverse = inverse
self.output_padding = expand_nd(ndim, output_padding)
......@@ -114,6 +115,7 @@ class SparseConvolution(SparseModule):
if CPU_ONLY_BUILD:
assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
self.algo = algo
self.fp32_accum = fp32_accum
# self.algo = ConvAlgo.Native
if self.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC:
if FILTER_HWIO:
......@@ -197,18 +199,25 @@ class SparseConvolution(SparseModule):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
raise ValueError(
"Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = self._calculate_fan_in_and_fan_out()
return fan_in if mode == 'fan_in' else fan_out
def _custom_kaiming_uniform_(self, tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
def _custom_kaiming_uniform_(self,
tensor,
a=0,
mode='fan_in',
nonlinearity='leaky_relu'):
r"""same as torch.init.kaiming_uniform_, with KRSC layout support
"""
fan = self._calculate_correct_fan(mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
bound = math.sqrt(
3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
......@@ -315,7 +324,8 @@ class SparseConvolution(SparseModule):
indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num
assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, datas)
self._check_subm_reuse_valid(input, spatial_shape,
datas)
else:
if input.benchmark:
torch.cuda.synchronize()
......@@ -334,7 +344,7 @@ class SparseConvolution(SparseModule):
msg += f"transpose={self.transposed}"
print(msg, file=sys.stderr)
spconv_save_debug_data(indices)
raise e
raise e
if input.benchmark:
torch.cuda.synchronize()
interval = time.time() - t
......@@ -407,7 +417,8 @@ class SparseConvolution(SparseModule):
mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
masks = datas.masks
assert self.subm, "only support reuse subm indices"
self._check_subm_reuse_valid(input, spatial_shape, datas)
self._check_subm_reuse_valid(input, spatial_shape,
datas)
else:
with input._timer.namespace("gen_pairs"):
......@@ -437,7 +448,7 @@ class SparseConvolution(SparseModule):
msg += f"transpose={self.transposed}"
print(msg, file=sys.stderr)
spconv_save_debug_data(indices)
raise e
raise e
outids = res[0]
num_inds_per_loc = res[1]
......@@ -479,7 +490,7 @@ class SparseConvolution(SparseModule):
pair_mask_fwd_splits, pair_mask_bwd_splits,
mask_argsort_fwd_splits, mask_argsort_bwd_splits,
num_activate_out, masks, self.training, self.subm,
input._timer)
input._timer, self.fp32_accum)
if self.bias is not None:
out_features += self.bias
if input.benchmark:
......@@ -496,21 +507,28 @@ class SparseConvolution(SparseModule):
out_tensor.spatial_shape = out_spatial_shape
return out_tensor
def _check_subm_reuse_valid(self, inp: SparseConvTensor, spatial_shape: List[int], datas: Union[ImplicitGemmIndiceData, IndiceData]):
def _check_subm_reuse_valid(self, inp: SparseConvTensor,
spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData,
IndiceData]):
assert datas.is_subm, "only support reuse subm indices"
if self.kernel_size != datas.ksize:
raise ValueError(f"subm with same indice_key must have same kernel"
raise ValueError(
f"subm with same indice_key must have same kernel"
f" size, expect {datas.ksize}, this layer {self.kernel_size}")
if self.dilation != datas.dilation:
raise ValueError(f"subm with same indice_key must have same dilation"
raise ValueError(
f"subm with same indice_key must have same dilation"
f", expect {datas.dilation}, this layer {self.dilation}")
if inp.spatial_shape != datas.spatial_shape:
raise ValueError(f"subm with same indice_key must have same spatial structure"
raise ValueError(
f"subm with same indice_key must have same spatial structure"
f", expect {datas.spatial_shape}, input {spatial_shape}")
if inp.indices.shape[0] != datas.indices.shape[0]:
raise ValueError(f"subm with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}")
raise ValueError(
f"subm with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
)
class SparseConv1d(SparseConvolution):
......@@ -525,6 +543,7 @@ class SparseConv1d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConv1d, self).__init__(1,
in_channels,
......@@ -537,6 +556,7 @@ class SparseConv1d(SparseConvolution):
bias,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -552,6 +572,7 @@ class SparseConv2d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConv2d, self).__init__(2,
in_channels,
......@@ -564,6 +585,7 @@ class SparseConv2d(SparseConvolution):
bias,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -579,6 +601,7 @@ class SparseConv3d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConv3d, self).__init__(3,
in_channels,
......@@ -591,6 +614,7 @@ class SparseConv3d(SparseConvolution):
bias,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -606,6 +630,7 @@ class SparseConv4d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConv4d, self).__init__(4,
in_channels,
......@@ -618,6 +643,7 @@ class SparseConv4d(SparseConvolution):
bias,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -633,6 +659,7 @@ class SparseConvTranspose1d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConvTranspose1d, self).__init__(1,
in_channels,
......@@ -646,6 +673,7 @@ class SparseConvTranspose1d(SparseConvolution):
transposed=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -661,6 +689,7 @@ class SparseConvTranspose2d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConvTranspose2d, self).__init__(2,
in_channels,
......@@ -674,6 +703,7 @@ class SparseConvTranspose2d(SparseConvolution):
transposed=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -689,6 +719,7 @@ class SparseConvTranspose3d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConvTranspose3d, self).__init__(3,
in_channels,
......@@ -702,6 +733,7 @@ class SparseConvTranspose3d(SparseConvolution):
transposed=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -717,6 +749,7 @@ class SparseConvTranspose4d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseConvTranspose4d, self).__init__(4,
in_channels,
......@@ -730,6 +763,7 @@ class SparseConvTranspose4d(SparseConvolution):
transposed=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -741,6 +775,7 @@ class SparseInverseConv1d(SparseConvolution):
indice_key,
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseInverseConv1d, self).__init__(1,
in_channels,
......@@ -750,6 +785,7 @@ class SparseInverseConv1d(SparseConvolution):
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -761,6 +797,7 @@ class SparseInverseConv2d(SparseConvolution):
indice_key,
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseInverseConv2d, self).__init__(2,
in_channels,
......@@ -770,6 +807,7 @@ class SparseInverseConv2d(SparseConvolution):
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -781,6 +819,7 @@ class SparseInverseConv3d(SparseConvolution):
indice_key,
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseInverseConv3d, self).__init__(3,
in_channels,
......@@ -790,6 +829,7 @@ class SparseInverseConv3d(SparseConvolution):
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -801,6 +841,7 @@ class SparseInverseConv4d(SparseConvolution):
indice_key,
bias=True,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SparseInverseConv4d, self).__init__(4,
in_channels,
......@@ -810,6 +851,7 @@ class SparseInverseConv4d(SparseConvolution):
inverse=True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -825,6 +867,7 @@ class SubMConv1d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SubMConv1d, self).__init__(1,
in_channels,
......@@ -838,6 +881,7 @@ class SubMConv1d(SparseConvolution):
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -853,6 +897,7 @@ class SubMConv2d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SubMConv2d, self).__init__(2,
in_channels,
......@@ -866,6 +911,7 @@ class SubMConv2d(SparseConvolution):
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -881,6 +927,7 @@ class SubMConv3d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SubMConv3d, self).__init__(3,
in_channels,
......@@ -894,6 +941,7 @@ class SubMConv3d(SparseConvolution):
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -909,6 +957,7 @@ class SubMConv4d(SparseConvolution):
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SubMConv4d, self).__init__(4,
in_channels,
......@@ -922,4 +971,5 @@ class SubMConv4d(SparseConvolution):
True,
indice_key=indice_key,
algo=algo,
fp32_accum=fp32_accum,
name=name)
......@@ -15,6 +15,8 @@
from cumm import tensorview as tv
import torch
from typing import Optional, List
from spconv.cppconstants import COMPILED_CUDA_ARCHS
import sys
_TORCH_DTYPE_TO_TV = {
torch.float32: tv.float32,
......@@ -63,7 +65,14 @@ def torch_tensors_to_tv(*tens: torch.Tensor):
def get_current_stream():
    """Return the ``cuda_stream`` handle of torch's current CUDA stream."""
    current = torch.cuda.current_stream()
    return current.cuda_stream
# Architectures get_arch() has already warned about, so that each unsupported
# arch is reported once instead of on every call.
_WARNED_UNCOMPILED_ARCHS = set()


def get_arch():
    """Return the compute capability of the current CUDA device.

    The value is whatever ``torch.cuda.get_device_capability()`` reports.
    When it is not a member of ``COMPILED_CUDA_ARCHS`` a warning is written
    to stderr, since the prebuilt kernels may then fail with an
    "invalid device function" error.

    This function is called from the convolution forward and backward paths,
    so the warning is printed only the first time a given arch is seen;
    printing it on every call would flood stderr during training.

    Returns:
        The capability tuple reported by torch, unchanged.
    """
    arch = torch.cuda.get_device_capability()
    if arch not in COMPILED_CUDA_ARCHS and arch not in _WARNED_UNCOMPILED_ARCHS:
        # Record before printing so repeated calls stay silent.
        _WARNED_UNCOMPILED_ARCHS.add(arch)
        print(f"[WARNING]your gpu arch {arch} isn't compiled in prebuilt, "
              f"may cause invalid device function. "
              f"available: {COMPILED_CUDA_ARCHS}", file=sys.stderr)
    return arch
if __name__ == "__main__":
a = torch.rand(2, 2)
atv = torch_tensor_to_tv(a)
......
......@@ -180,14 +180,16 @@ class SparseImplicitGemmFunction(Function):
masks: List[np.ndarray],
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
try:
out, mask_out, mask_width = ops.implicit_gemm(features, filters,
pair_fwd,
pair_mask_fwd_splits,
mask_argsort_fwd_splits,
num_activate_out, masks,
is_train, is_subm, timer)
is_train, is_subm, timer,
fp32_accum)
except Exception as e:
msg = "[Exception|implicit_gemm]"
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
......@@ -209,6 +211,7 @@ class SparseImplicitGemmFunction(Function):
# ctx.num_activate_out = num_activate_out
ctx.masks = masks
ctx.is_subm = is_subm
ctx.fp32_accum = fp32_accum
return out
@staticmethod
......@@ -226,6 +229,8 @@ class SparseImplicitGemmFunction(Function):
masks = ctx.masks
is_subm = ctx.is_subm
timer = ctx.timer
fp32_accum = ctx.fp32_accum
try:
input_bp, filters_bp = ops.implicit_gemm_backward(
features,
......@@ -241,7 +246,8 @@ class SparseImplicitGemmFunction(Function):
masks=masks,
mask_width=mask_width,
is_subm=is_subm,
timer=timer)
timer=timer,
fp32_accum=fp32_accum)
except Exception as e:
msg = "[Exception|implicit_gemm_backward]"
msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
......@@ -252,7 +258,7 @@ class SparseImplicitGemmFunction(Function):
masks))
raise e
None_9 = [None] * 11
None_9 = [None] * 12
return (input_bp, filters_bp, *None_9)
......
......@@ -23,7 +23,7 @@ import spconv
from spconv.core import AlgoHint, ConvAlgo
from typing import List, Optional, Union
from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream
from spconv.pytorch.cppcore import torch_tensor_to_tv, get_current_stream, get_arch
from spconv.core_cc.csrc.sparse.all import SpconvOps
import spconv.core_cc as _ext
......@@ -685,7 +685,7 @@ def indice_conv(features: torch.Tensor,
profile_idx = i
assert nhot_profile > 0, "this shouldn't happen"
# print(nhot_profile, indice_pair_num_cpu)
arch = torch.cuda.get_device_capability()
arch = get_arch()
tuned_res = GEMM.get_tuned_algo(a.dtype,
filters_tv.dtype,
......@@ -849,7 +849,7 @@ def indice_conv_backward(features: torch.Tensor,
return (din, dfilters.reshape(filters_shape))
maxnhot = max(indice_pair_num_cpu)
arch = torch.cuda.get_device_capability()
arch = get_arch()
filters_tv = torch_tensor_to_tv(filters)
dfilters_tv = torch_tensor_to_tv(dfilters)
......@@ -1097,7 +1097,8 @@ def implicit_gemm(features: torch.Tensor,
masks: List[np.ndarray],
is_train: bool,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
stream = get_current_stream()
# if DEBUG:
......@@ -1131,7 +1132,7 @@ def implicit_gemm(features: torch.Tensor,
features_tv = torch_tensor_to_tv(features)
filters_tv = torch_tensor_to_tv(filters)
out_features_tv = torch_tensor_to_tv(out_features)
arch = torch.cuda.get_device_capability()
arch = get_arch()
pair_mask_fwd_split_tvs = [
torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
]
......@@ -1159,7 +1160,8 @@ def implicit_gemm(features: torch.Tensor,
indices=pair_fwd_tv,
reverse_mask=False,
mask_filter=masks[0].item(),
stream=stream)
stream=stream,
fp32_accum=fp32_accum)
mask_width = tune_res.algo_desp.tile_shape[0]
if is_train:
mask_output_fwd = torch.empty(
......@@ -1226,7 +1228,8 @@ def implicit_gemm_backward(features: torch.Tensor,
masks: List[np.ndarray],
mask_width: int,
is_subm: bool,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
timer: CUDAKernelTimer = CUDAKernelTimer(False),
fp32_accum: Optional[bool] = None):
# print(out_bp.mean(), out_bp.max(), out_bp.min())
if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
......@@ -1263,7 +1266,7 @@ def implicit_gemm_backward(features: torch.Tensor,
dout_tv = torch_tensor_to_tv(out_bp)
din_tv = torch_tensor_to_tv(din)
mask_output_fwd_tv = torch_tensor_to_tv(mask_output_fwd, dtype=tv.uint32)
arch = torch.cuda.get_device_capability()
arch = get_arch()
pair_mask_fwd_split_tvs = [
torch_tensor_to_tv(x, dtype=tv.uint32) for x in pair_mask_fwd_splits
]
......@@ -1309,7 +1312,8 @@ def implicit_gemm_backward(features: torch.Tensor,
indices=pair_bwd_tv,
reverse_mask=is_subm,
mask_filter=masks[0].item(),
stream=stream)
stream=stream,
fp32_accum=fp32_accum)
if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment