Commit 19a599e1 authored by yan.yan's avatar yan.yan
Browse files

disable tf32 by default

parent 238d6a83
......@@ -56,6 +56,16 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s
* spconv 2.2: ampere feature support (by [EvernightAurora](https://github.com/EvernightAurora)), pure c++ code generation, nvrtc, drop python 3.6
## Spconv 2.2 vs Spconv 2.1
* faster fp16 kernels (~5-30%) in ampere GPUs (tested in RTX 3090)
* greatly faster int8 kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090)
* no python 3.6 support
* nvrtc support: kernel in old GPUs will be compiled in runtime.
* [libspconv](docs/PURE_CPP_BUILD.md): pure c++ build of all spconv ops. see [example](example/libspconv/run_build.sh)
* tf32 kernels, faster fp32 training, disabled by default. set ```import spconv as spconv_core; spconv_core.constants.SPCONV_ALLOW_TF32 = True``` to enable them.
## Spconv 2.1 vs Spconv 1.x
* spconv now can be installed by **pip**. see install section in readme for more details. Users don't need to build manually anymore!
......@@ -66,14 +76,6 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s
* [doesn't depend on pytorch binary](docs/FAQ.md#What-does-no-dependency-on-pytorch-mean), but you may need at least pytorch >= 1.5.0 to run spconv 2.x.
* since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference.
## Spconv 2.2 vs Spconv 2.1
* faster fp16 kernels (~5-30%) in ampere GPUs (tested in RTX 3090)
* greatly faster int8 kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090)
* no python 3.6 support
* nvrtc support: kernel in old GPUs will be compiled in runtime.
* [libspconv](docs/PURE_CPP_BUILD.md): pure c++ build of all spconv ops. see [example](example/libspconv/run_build.sh)
## Usage
Firstly you need to use ```import spconv.pytorch as spconv``` in spconv 2.x.
......
......@@ -24,5 +24,5 @@
* Currently fast algorithm only support kernel volume (prod of kernel size) <= 32, so don't use large kernel size.
* make sure your channel size is multiple of 8 when using fp16. multiple of 32 is better.
* spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible.
* If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training.
See [benchmark](BENCHMARK.md) for more performance details of different algorithms.
......@@ -301,7 +301,8 @@ class SimpleGemm:
trans_b: bool,
trans_c: bool,
arch: Tuple[int, int],
shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle):
shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
use_tf32: bool = True):
if trans_c:
trans_a = not trans_a
trans_b = not trans_b
......@@ -327,6 +328,9 @@ class SimpleGemm:
# skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue
if not use_tf32:
if desp.tensorop[0] > 0 and a.dtype == tv.float32 and b.dtype == tv.float32:
continue
lda = a.stride[0]
ldb = b.stride[0]
ldc = c.stride[0]
......@@ -424,14 +428,15 @@ class SimpleGemm:
gather_data: tv.Tensor = tv.Tensor(),
scatter_data: tv.Tensor = tv.Tensor(),
# mm_func
stream: int = 0):
stream: int = 0,
use_tf32: bool = True):
m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
trans_b, trans_c,
shuffle_type.value,
a_inds.shape, b_inds.shape,
c_inds.shape)
avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c,
arch, shuffle_type)
arch, shuffle_type, use_tf32)
# c may be weight, may non-contiguous.
# cumm.tensorview.Tensor don't support non-contiguous clone
c_ = c.clone_whole_storage()
......@@ -660,7 +665,8 @@ class SimpleConv:
arch: Tuple[int, int],
op_type: ConvOpType,
mask_width: int,
fp32_accum: Optional[bool] = None):
fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = []
......@@ -692,6 +698,10 @@ class SimpleConv:
# skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue
if not use_tf32:
if (desp.tensorop[0] > 0 and inp.dtype == tv.float32
and weight.dtype == tv.float32 and out.dtype == tv.float32):
continue
if arch >= (7, 0) and is_fp16:
if desp.algo == GemmAlgo.Simt:
continue
......@@ -796,10 +806,11 @@ class SimpleConv:
alpha: float = 1.0,
beta: float = 0.0,
stream: int = 0,
fp32_accum: Optional[bool] = None):
fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width,
fp32_accum)
fp32_accum, use_tf32)
inp = inp.clone()
weight = weight.clone()
output = output.clone()
......
......@@ -112,4 +112,7 @@ SPCONV_CPP_GEMM = True
SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
\ No newline at end of file
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32 = False
......@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch:
"""
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool) -> List[ConvAlgoDesp]:
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]:
"""
Args:
inp:
......@@ -37,6 +37,7 @@ class ConvTunerSimple:
mask_width:
auto_fp32_accum:
fp32_accum:
use_tf32:
"""
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
......@@ -47,7 +48,7 @@ class ConvTunerSimple:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5) -> Tuple[ConvTuneResult, float]:
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
......@@ -74,6 +75,7 @@ class ConvTunerSimple:
auto_fp32_accum:
fp32_accum:
num_run:
use_tf32:
"""
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
......
......@@ -20,7 +20,7 @@ class GemmTunerSimple:
arch:
"""
...
def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int) -> List[GemmAlgoDesp]:
def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, use_tf32: bool = True) -> List[GemmAlgoDesp]:
"""
Args:
a:
......@@ -31,6 +31,7 @@ class GemmTunerSimple:
trans_c:
arch:
shuffle_type:
use_tf32:
"""
...
def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
......@@ -41,7 +42,7 @@ class GemmTunerSimple:
stream_int:
"""
...
def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5) -> Tuple[GemmTuneResult, float]:
def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5, use_tf32: bool = True) -> Tuple[GemmTuneResult, float]:
"""
Args:
a:
......@@ -60,6 +61,7 @@ class GemmTunerSimple:
beta:
stream_int:
num_run:
use_tf32:
"""
...
def get_tuned_algo(self, a_dtype: int, b_dtype: int, c_dtype: int, a_shape: List[int], b_shape: List[int], c_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds_shape: List[int], b_inds_shape: List[int], c_inds_shape: List[int], hint: int = 0) -> Tuple[Any, bool]:
......
......@@ -12,7 +12,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None:
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> None:
"""
1. this function need to take a out features
that from subm first mm.
......@@ -37,10 +37,11 @@ class ConvGemmOps:
act_alpha:
act_beta:
act_type:
use_tf32:
"""
...
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, use_tf32: bool = True) -> None:
"""
Args:
allocator:
......@@ -58,10 +59,11 @@ class ConvGemmOps:
subm:
algo:
stream_int:
use_tf32:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> Tuple[int, Any]:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
"""
Args:
allocator:
......@@ -84,10 +86,11 @@ class ConvGemmOps:
act_alpha:
act_beta:
act_type:
use_tf32:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
"""
Args:
allocator:
......@@ -110,5 +113,6 @@ class ConvGemmOps:
timer:
auto_fp32_accum:
fp32_accum:
use_tf32:
"""
...
......@@ -538,6 +538,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>")
code.arg("shuffle_type", "int")
code.arg("use_tf32", "bool", "true")
code.raw(f"""
if (trans_c){{
trans_a = !trans_a;
......@@ -562,6 +564,12 @@ class GemmTunerSimple(pccm.ParameterizedClass):
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue;
}}
if (!use_tf32){{
if (desp.tensorop[0] > 0 && a.dtype() == tv::float32 && b.dtype() == tv::float32){{
// tf32 op
continue;
}}
}}
auto lda = a.stride(0);
auto ldb = b.stride(0);
auto ldc = c.stride(0);
......@@ -656,6 +664,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
code.raw("return std::make_tuple(GemmTuneResult(), -1.0f);")
......@@ -677,8 +687,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk);
auto avail = get_all_available(a, b, c, trans_a,
trans_b, trans_c, arch, shuffle_type);
auto avail = get_all_available(a, b, c, trans_a, trans_b,
trans_c, arch, shuffle_type, use_tf32);
auto c_ = c.clone_whole_storage();
std::vector<GemmTuneResult> all_profile_res;
std::vector<int> splitk_tests;
......@@ -969,6 +979,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("mask_width", "int")
code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool")
code.arg("use_tf32", "bool", "true")
code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
......@@ -1010,6 +1021,12 @@ class ConvTunerSimple(pccm.ParameterizedClass):
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue;
}}
if (!use_tf32){{
if (desp.tensorop[0] > 0 && inp.dtype() == tv::float32 && weight.dtype() == tv::float32 && out.dtype() == tv::float32){{
// tf32 op
continue;
}}
}}
if (arch >= std::make_tuple(7, 0) && is_fp16){{
// skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
......@@ -1086,6 +1103,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret(
......@@ -1099,7 +1118,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum);
auto_fp32_accum, fp32_accum, use_tf32);
inp = inp.clone();
weight = weight.clone();
output = output.clone();
......@@ -1408,6 +1427,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true")
code.raw(f"""
int kv_dim, out_channel, kv;
......@@ -1571,7 +1591,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.Fowrard.value},
1.0,
0.0,
stream_int);
stream_int,
5, // num_run
use_tf32);
tune_res = std::get<0>(tune_res_time);
}}
......@@ -1640,6 +1662,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("use_tf32", "bool", "true")
code.raw(f"""
int kv_dim, out_channel, kv;
......@@ -1794,7 +1817,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.BackwardInput.value},
1.0,
0.0,
stream_int);
stream_int,
5, // num_run
use_tf32);
tuned_res_dgrad = std::get<0>(tune_res_time);
}}
tv::Tensor a_wgrad, b_wgrad;
......@@ -1852,7 +1877,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.BackwardWeight.value},
1.0,
0.0,
stream_int);
stream_int,
5, // num_run
use_tf32);
tuned_res_wgrad = std::get<0>(tune_res_time);
}}
......@@ -1966,6 +1993,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD:
......@@ -2025,7 +2053,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
fp32_accum,
5, // num_run
use_tf32);
tune_res = std::get<0>(tune_res_time);
}}
......@@ -2109,6 +2139,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
......@@ -2192,7 +2223,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
fp32_accum,
5, // num_run
use_tf32);
dgrad_tune_res = std::get<0>(tune_res_time);
}}
if (!wgrad_exists){{
......@@ -2214,7 +2247,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
fp32_accum,
5, // num_run
use_tf32);
wgrad_tune_res = std::get<0>(tune_res_time);
}}
int ws_size = conv_tuner.query_workspace_size(wgrad_tune_res.algo_desp,
......
......@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32
import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps
......@@ -831,7 +831,8 @@ def indice_conv(features: torch.Tensor,
FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv, arch,
num_activate_out, inverse, subm, algo.value,
stream, bias_tv, act_alpha, act_beta, act_type)
stream, bias_tv, act_alpha, act_beta, act_type,
use_tf32=SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features
if not features.is_cuda:
......@@ -1011,7 +1012,8 @@ def indice_conv(features: torch.Tensor,
alpha=1.0,
beta=0.0,
hint=AlgoHint.Fowrard.value,
stream=stream)
stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
# CONV.stream_synchronize(stream)
# t = time.time()
with timer.record("forward", stream):
......@@ -1103,7 +1105,7 @@ def indice_conv_backward(features: torch.Tensor,
features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv,
arch, inverse, subm, algo.value,
stream)
stream, use_tf32=SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters]
return din, df
......@@ -1270,7 +1272,8 @@ def indice_conv_backward(features: torch.Tensor,
alpha=1.0,
beta=0.0,
hint=AlgoHint.BackwardInput.value,
stream=stream)
stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
if is_KC_not_CK:
a_wgrad = out_bp_tv
b_wgrad = features_tv
......@@ -1317,7 +1320,8 @@ def indice_conv_backward(features: torch.Tensor,
alpha=1.0,
beta=0.0,
hint=AlgoHint.BackwardWeight.value,
stream=stream)
stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
# print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
# get workspace size for wgrad
if is_KC_not_CK:
......@@ -1462,7 +1466,8 @@ def implicit_gemm(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type)
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train:
......@@ -1529,7 +1534,8 @@ def implicit_gemm(features: torch.Tensor,
reverse_mask=False,
mask_filter=masks[0].item(),
stream=stream,
fp32_accum=fp32_accum)
fp32_accum=fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
mask_width = tune_res.algo_desp.tile_shape[0]
if is_train:
mask_output_fwd = torch.empty(
......@@ -1741,7 +1747,8 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum)
timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters]
return din, dfilters
......@@ -1817,7 +1824,8 @@ def implicit_gemm_backward(features: torch.Tensor,
reverse_mask=is_subm,
mask_filter=masks[0].item(),
stream=stream,
fp32_accum=fp32_accum)
fp32_accum=fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight,
......@@ -1835,7 +1843,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(),
mask_output=tv.Tensor(),
mask_width=mask_width,
stream=stream)
stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight,
......
......@@ -34,6 +34,7 @@ from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult
from spconv.pytorch.core import SparseConvTensor
from spconv.test_utils import TestCase
from cumm import tensorview as tv
from spconv.constants import SPCONV_ALLOW_TF32
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
import os
from cumm.gemm.codeops import div_up
......@@ -279,6 +280,8 @@ def _test_impgemm_conv_cuda(subm: bool):
shapes = [[19, 18, 17]]
batchsizes = [1]
dtypes = [np.float32, np.float16]
# dtypes = [np.float16]
# dtypes = [np.int8]
test_case = TestCase()
# in_channels = [32]
......@@ -331,9 +334,11 @@ def _test_impgemm_conv_cuda(subm: bool):
if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False)
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1)
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1,
use_tf32=SPCONV_ALLOW_TF32)
if op_type == ConvOpType.kForward and tester.check_act:
act = tv.gemm.Activation.ReLU
else:
......@@ -535,9 +540,11 @@ def _test_impgemm_conv_cuda(subm: bool):
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False)
ConvOpType.kBackwardWeight.value, mask_width, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width)
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps:
weight_tv.zero_()
if subm:
......@@ -753,9 +760,11 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, tester.kv // 2)
c = inp_tv
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value)
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value,
use_tf32=SPCONV_ALLOW_TF32)
else:
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC)
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps:
if subm:
......@@ -827,9 +836,11 @@ def _test_native_conv_cuda(subm: bool):
b = inp_tv
c = weight_tv.select(1, tester.kv // 2)
if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value)
avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value,
use_tf32=SPCONV_ALLOW_TF32)
else:
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB)
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps:
# print(desp, C, K, k, s, p, d)
......@@ -900,8 +911,8 @@ def test_all_algo_unit():
# for i in range(5):
_test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True)
# _test_native_conv_cuda(False)
_test_native_conv_cuda(True)
_test_native_conv_cuda(False)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment