Commit 19a599e1 authored by yan.yan's avatar yan.yan
Browse files

disable tf32 by default

parent 238d6a83
...@@ -56,6 +56,16 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s ...@@ -56,6 +56,16 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s
* spconv 2.2: ampere feature support (by [EvernightAurora](https://github.com/EvernightAurora)), pure c++ code generation, nvrtc, drop python 3.6 * spconv 2.2: ampere feature support (by [EvernightAurora](https://github.com/EvernightAurora)), pure c++ code generation, nvrtc, drop python 3.6
## Spconv 2.2 vs Spconv 2.1
* faster fp16 kernels (~5-30%) in ampere GPUs (tested in RTX 3090)
* greatly faster int8 kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090)
* no python 3.6 support
* nvrtc support: kernel in old GPUs will be compiled in runtime.
* [libspconv](docs/PURE_CPP_BUILD.md): pure c++ build of all spconv ops. see [example](example/libspconv/run_build.sh)
* tf32 kernels, faster fp32 training, disabled by default. set ```import spconv as spconv_core; spconv_core.constants.SPCONV_ALLOW_TF32 = True``` to enable them.
## Spconv 2.1 vs Spconv 1.x ## Spconv 2.1 vs Spconv 1.x
* spconv now can be installed by **pip**. see install section in readme for more details. Users don't need to build manually anymore! * spconv now can be installed by **pip**. see install section in readme for more details. Users don't need to build manually anymore!
...@@ -66,14 +76,6 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s ...@@ -66,14 +76,6 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand s
* [doesn't depend on pytorch binary](docs/FAQ.md#What-does-no-dependency-on-pytorch-mean), but you may need at least pytorch >= 1.5.0 to run spconv 2.x. * [doesn't depend on pytorch binary](docs/FAQ.md#What-does-no-dependency-on-pytorch-mean), but you may need at least pytorch >= 1.5.0 to run spconv 2.x.
* since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference. * since spconv 2.x doesn't depend on pytorch binary (never in future), it's impossible to support torch.jit/libtorch inference.
## Spconv 2.2 vs Spconv 2.1
* faster fp16 kernels (~5-30%) in ampere GPUs (tested in RTX 3090)
* greatly faster int8 kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090)
* no python 3.6 support
* nvrtc support: kernel in old GPUs will be compiled in runtime.
* [libspconv](docs/PURE_CPP_BUILD.md): pure c++ build of all spconv ops. see [example](example/libspconv/run_build.sh)
## Usage ## Usage
Firstly you need to use ```import spconv.pytorch as spconv``` in spconv 2.x. Firstly you need to use ```import spconv.pytorch as spconv``` in spconv 2.x.
......
...@@ -24,5 +24,5 @@ ...@@ -24,5 +24,5 @@
* Currently fast algorithm only support kernel volume (prod of kernel size) <= 32, so don't use large kernel size. * Currently fast algorithm only support kernel volume (prod of kernel size) <= 32, so don't use large kernel size.
* make sure your channel size is multiple of 8 when using fp16. multiple of 32 is better. * make sure your channel size is multiple of 8 when using fp16. multiple of 32 is better.
* spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible. * spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible.
* If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training.
See [benchmark](BENCHMARK.md) for more performance details of different algorithms. See [benchmark](BENCHMARK.md) for more performance details of different algorithms.
...@@ -301,7 +301,8 @@ class SimpleGemm: ...@@ -301,7 +301,8 @@ class SimpleGemm:
trans_b: bool, trans_b: bool,
trans_c: bool, trans_c: bool,
arch: Tuple[int, int], arch: Tuple[int, int],
shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle): shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
use_tf32: bool = True):
if trans_c: if trans_c:
trans_a = not trans_a trans_a = not trans_a
trans_b = not trans_b trans_b = not trans_b
...@@ -327,6 +328,9 @@ class SimpleGemm: ...@@ -327,6 +328,9 @@ class SimpleGemm:
# skip volta tensor op since it is very slow in architectures except volta. # skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue continue
if not use_tf32:
if desp.tensorop[0] > 0 and a.dtype == tv.float32 and b.dtype == tv.float32:
continue
lda = a.stride[0] lda = a.stride[0]
ldb = b.stride[0] ldb = b.stride[0]
ldc = c.stride[0] ldc = c.stride[0]
...@@ -424,14 +428,15 @@ class SimpleGemm: ...@@ -424,14 +428,15 @@ class SimpleGemm:
gather_data: tv.Tensor = tv.Tensor(), gather_data: tv.Tensor = tv.Tensor(),
scatter_data: tv.Tensor = tv.Tensor(), scatter_data: tv.Tensor = tv.Tensor(),
# mm_func # mm_func
stream: int = 0): stream: int = 0,
use_tf32: bool = True):
m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a, m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
trans_b, trans_c, trans_b, trans_c,
shuffle_type.value, shuffle_type.value,
a_inds.shape, b_inds.shape, a_inds.shape, b_inds.shape,
c_inds.shape) c_inds.shape)
avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c, avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c,
arch, shuffle_type) arch, shuffle_type, use_tf32)
# c may be weight, may non-contiguous. # c may be weight, may non-contiguous.
# cumm.tensorview.Tensor don't support non-contiguous clone # cumm.tensorview.Tensor don't support non-contiguous clone
c_ = c.clone_whole_storage() c_ = c.clone_whole_storage()
...@@ -660,7 +665,8 @@ class SimpleConv: ...@@ -660,7 +665,8 @@ class SimpleConv:
arch: Tuple[int, int], arch: Tuple[int, int],
op_type: ConvOpType, op_type: ConvOpType,
mask_width: int, mask_width: int,
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
avail_algos = get_available_algo_str_from_arch(arch) avail_algos = get_available_algo_str_from_arch(arch)
finally_algos: List[ConvAlgoDesp] = [] finally_algos: List[ConvAlgoDesp] = []
...@@ -692,6 +698,10 @@ class SimpleConv: ...@@ -692,6 +698,10 @@ class SimpleConv:
# skip volta tensor op since it is very slow in architectures except volta. # skip volta tensor op since it is very slow in architectures except volta.
if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
continue continue
if not use_tf32:
if (desp.tensorop[0] > 0 and inp.dtype == tv.float32
and weight.dtype == tv.float32 and out.dtype == tv.float32):
continue
if arch >= (7, 0) and is_fp16: if arch >= (7, 0) and is_fp16:
if desp.algo == GemmAlgo.Simt: if desp.algo == GemmAlgo.Simt:
continue continue
...@@ -796,10 +806,11 @@ class SimpleConv: ...@@ -796,10 +806,11 @@ class SimpleConv:
alpha: float = 1.0, alpha: float = 1.0,
beta: float = 0.0, beta: float = 0.0,
stream: int = 0, stream: int = 0,
fp32_accum: Optional[bool] = None): fp32_accum: Optional[bool] = None,
use_tf32: bool = True):
avail = self.get_all_available(inp, weight, output, layout_i, layout_w, avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, arch, op_type, mask_width, layout_o, arch, op_type, mask_width,
fp32_accum) fp32_accum, use_tf32)
inp = inp.clone() inp = inp.clone()
weight = weight.clone() weight = weight.clone()
output = output.clone() output = output.clone()
......
...@@ -113,3 +113,6 @@ SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1" ...@@ -113,3 +113,6 @@ SPCONV_FX_TRACE_MODE = os.getenv("SPCONV_FX_TRACE_MODE", "0") == "1"
SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1 SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE = 1.1
SPCONV_ALLOW_TF32 = False
...@@ -20,7 +20,7 @@ class ConvTunerSimple: ...@@ -20,7 +20,7 @@ class ConvTunerSimple:
arch: arch:
""" """
... ...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool) -> List[ConvAlgoDesp]: def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool, use_tf32: bool = True) -> List[ConvAlgoDesp]:
""" """
Args: Args:
inp: inp:
...@@ -37,6 +37,7 @@ class ConvTunerSimple: ...@@ -37,6 +37,7 @@ class ConvTunerSimple:
mask_width: mask_width:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
use_tf32:
""" """
... ...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams: def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...@@ -47,7 +48,7 @@ class ConvTunerSimple: ...@@ -47,7 +48,7 @@ class ConvTunerSimple:
stream_int: stream_int:
""" """
... ...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5) -> Tuple[ConvTuneResult, float]: def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5, use_tf32: bool = True) -> Tuple[ConvTuneResult, float]:
""" """
Args: Args:
op_type: op_type:
...@@ -74,6 +75,7 @@ class ConvTunerSimple: ...@@ -74,6 +75,7 @@ class ConvTunerSimple:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
num_run: num_run:
use_tf32:
""" """
... ...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]: def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
......
...@@ -20,7 +20,7 @@ class GemmTunerSimple: ...@@ -20,7 +20,7 @@ class GemmTunerSimple:
arch: arch:
""" """
... ...
def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int) -> List[GemmAlgoDesp]: def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, use_tf32: bool = True) -> List[GemmAlgoDesp]:
""" """
Args: Args:
a: a:
...@@ -31,6 +31,7 @@ class GemmTunerSimple: ...@@ -31,6 +31,7 @@ class GemmTunerSimple:
trans_c: trans_c:
arch: arch:
shuffle_type: shuffle_type:
use_tf32:
""" """
... ...
def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams: def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
...@@ -41,7 +42,7 @@ class GemmTunerSimple: ...@@ -41,7 +42,7 @@ class GemmTunerSimple:
stream_int: stream_int:
""" """
... ...
def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5) -> Tuple[GemmTuneResult, float]: def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5, use_tf32: bool = True) -> Tuple[GemmTuneResult, float]:
""" """
Args: Args:
a: a:
...@@ -60,6 +61,7 @@ class GemmTunerSimple: ...@@ -60,6 +61,7 @@ class GemmTunerSimple:
beta: beta:
stream_int: stream_int:
num_run: num_run:
use_tf32:
""" """
... ...
def get_tuned_algo(self, a_dtype: int, b_dtype: int, c_dtype: int, a_shape: List[int], b_shape: List[int], c_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds_shape: List[int], b_inds_shape: List[int], c_inds_shape: List[int], hint: int = 0) -> Tuple[Any, bool]: def get_tuned_algo(self, a_dtype: int, b_dtype: int, c_dtype: int, a_shape: List[int], b_shape: List[int], c_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds_shape: List[int], b_inds_shape: List[int], c_inds_shape: List[int], hint: int = 0) -> Tuple[Any, bool]:
......
...@@ -12,7 +12,7 @@ class ConvGemmOps: ...@@ -12,7 +12,7 @@ class ConvGemmOps:
""" """
... ...
@staticmethod @staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> None: def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> None:
""" """
1. this function need to take a out features 1. this function need to take a out features
that from subm first mm. that from subm first mm.
...@@ -37,10 +37,11 @@ class ConvGemmOps: ...@@ -37,10 +37,11 @@ class ConvGemmOps:
act_alpha: act_alpha:
act_beta: act_beta:
act_type: act_type:
use_tf32:
""" """
... ...
@staticmethod @staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None: def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, arch: Tuple[int, int], inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0, use_tf32: bool = True) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -58,10 +59,11 @@ class ConvGemmOps: ...@@ -58,10 +59,11 @@ class ConvGemmOps:
subm: subm:
algo: algo:
stream_int: stream_int:
use_tf32:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_) -> Tuple[int, Any]: def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, arch: Tuple[int, int], is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, bias: Tensor = Tensor(), act_alpha: float = 0.0, act_beta: float = 0.0, act_type: Activation = Activation.None_, use_tf32: bool = True) -> Tuple[int, Any]:
""" """
Args: Args:
allocator: allocator:
...@@ -84,10 +86,11 @@ class ConvGemmOps: ...@@ -84,10 +86,11 @@ class ConvGemmOps:
act_alpha: act_alpha:
act_beta: act_beta:
act_type: act_type:
use_tf32:
""" """
... ...
@staticmethod @staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None: def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, arch: Tuple[int, int], mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False, use_tf32: bool = True) -> None:
""" """
Args: Args:
allocator: allocator:
...@@ -110,5 +113,6 @@ class ConvGemmOps: ...@@ -110,5 +113,6 @@ class ConvGemmOps:
timer: timer:
auto_fp32_accum: auto_fp32_accum:
fp32_accum: fp32_accum:
use_tf32:
""" """
... ...
...@@ -538,6 +538,8 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -538,6 +538,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
code.arg("trans_a, trans_b, trans_c", "bool") code.arg("trans_a, trans_b, trans_c", "bool")
code.arg("arch", "std::tuple<int, int>") code.arg("arch", "std::tuple<int, int>")
code.arg("shuffle_type", "int") code.arg("shuffle_type", "int")
code.arg("use_tf32", "bool", "true")
code.raw(f""" code.raw(f"""
if (trans_c){{ if (trans_c){{
trans_a = !trans_a; trans_a = !trans_a;
...@@ -562,6 +564,12 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -562,6 +564,12 @@ class GemmTunerSimple(pccm.ParameterizedClass):
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{ if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue; continue;
}} }}
if (!use_tf32){{
if (desp.tensorop[0] > 0 && a.dtype() == tv::float32 && b.dtype() == tv::float32){{
// tf32 op
continue;
}}
}}
auto lda = a.stride(0); auto lda = a.stride(0);
auto ldb = b.stride(0); auto ldb = b.stride(0);
auto ldc = c.stride(0); auto ldc = c.stride(0);
...@@ -656,6 +664,8 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -656,6 +664,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
code.arg("beta", "float", "0.0") code.arg("beta", "float", "0.0")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("num_run", "int", "5") code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
code.raw("return std::make_tuple(GemmTuneResult(), -1.0f);") code.raw("return std::make_tuple(GemmTuneResult(), -1.0f);")
...@@ -677,8 +687,8 @@ class GemmTunerSimple(pccm.ParameterizedClass): ...@@ -677,8 +687,8 @@ class GemmTunerSimple(pccm.ParameterizedClass):
auto n = std::get<1>(mnk); auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk); auto k = std::get<2>(mnk);
auto avail = get_all_available(a, b, c, trans_a, auto avail = get_all_available(a, b, c, trans_a, trans_b,
trans_b, trans_c, arch, shuffle_type); trans_c, arch, shuffle_type, use_tf32);
auto c_ = c.clone_whole_storage(); auto c_ = c.clone_whole_storage();
std::vector<GemmTuneResult> all_profile_res; std::vector<GemmTuneResult> all_profile_res;
std::vector<int> splitk_tests; std::vector<int> splitk_tests;
...@@ -969,6 +979,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -969,6 +979,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("mask_width", "int") code.arg("mask_width", "int")
code.arg("auto_fp32_accum", "bool") code.arg("auto_fp32_accum", "bool")
code.arg("fp32_accum", "bool") code.arg("fp32_accum", "bool")
code.arg("use_tf32", "bool", "true")
code.raw(f""" code.raw(f"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type); tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
...@@ -1010,6 +1021,12 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1010,6 +1021,12 @@ class ConvTunerSimple(pccm.ParameterizedClass):
if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{ if (arch >= std::make_tuple(7, 5) && desp.algo == {pccm.literal(GemmAlgo.Volta.value)}){{
continue; continue;
}} }}
if (!use_tf32){{
if (desp.tensorop[0] > 0 && inp.dtype() == tv::float32 && weight.dtype() == tv::float32 && out.dtype() == tv::float32){{
// tf32 op
continue;
}}
}}
if (arch >= std::make_tuple(7, 0) && is_fp16){{ if (arch >= std::make_tuple(7, 0) && is_fp16){{
// skip simt fp16 kernels if we have tensor core // skip simt fp16 kernels if we have tensor core
if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{ if (desp.algo == {pccm.literal(GemmAlgo.Simt.value)}){{
...@@ -1086,6 +1103,8 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1086,6 +1103,8 @@ class ConvTunerSimple(pccm.ParameterizedClass):
code.arg("auto_fp32_accum", "bool", "true") code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false") code.arg("fp32_accum", "bool", "false")
code.arg("num_run", "int", "5") code.arg("num_run", "int", "5")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
return code.ret( return code.ret(
...@@ -1099,7 +1118,7 @@ class ConvTunerSimple(pccm.ParameterizedClass): ...@@ -1099,7 +1118,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto avail = get_all_available(inp, weight, output, layout_i, layout_w, auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o, layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width, arch, op_type, mask_width,
auto_fp32_accum, fp32_accum); auto_fp32_accum, fp32_accum, use_tf32);
inp = inp.clone(); inp = inp.clone();
weight = weight.clone(); weight = weight.clone();
output = output.clone(); output = output.clone();
...@@ -1408,6 +1427,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1408,6 +1427,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0") code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true")
code.raw(f""" code.raw(f"""
int kv_dim, out_channel, kv; int kv_dim, out_channel, kv;
...@@ -1571,7 +1591,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1571,7 +1591,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.Fowrard.value}, {AlgoHint.Fowrard.value},
1.0, 1.0,
0.0, 0.0,
stream_int); stream_int,
5, // num_run
use_tf32);
tune_res = std::get<0>(tune_res_time); tune_res = std::get<0>(tune_res_time);
}} }}
...@@ -1640,6 +1662,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1640,6 +1662,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("subm", "bool", "false") code.arg("subm", "bool", "false")
code.arg("algo", "int", f"{ConvAlgo.Native.value}") code.arg("algo", "int", f"{ConvAlgo.Native.value}")
code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
code.arg("use_tf32", "bool", "true")
code.raw(f""" code.raw(f"""
int kv_dim, out_channel, kv; int kv_dim, out_channel, kv;
...@@ -1794,7 +1817,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1794,7 +1817,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.BackwardInput.value}, {AlgoHint.BackwardInput.value},
1.0, 1.0,
0.0, 0.0,
stream_int); stream_int,
5, // num_run
use_tf32);
tuned_res_dgrad = std::get<0>(tune_res_time); tuned_res_dgrad = std::get<0>(tune_res_time);
}} }}
tv::Tensor a_wgrad, b_wgrad; tv::Tensor a_wgrad, b_wgrad;
...@@ -1852,7 +1877,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1852,7 +1877,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
{AlgoHint.BackwardWeight.value}, {AlgoHint.BackwardWeight.value},
1.0, 1.0,
0.0, 0.0,
stream_int); stream_int,
5, // num_run
use_tf32);
tuned_res_wgrad = std::get<0>(tune_res_time); tuned_res_wgrad = std::get<0>(tune_res_time);
}} }}
...@@ -1966,6 +1993,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -1966,6 +1993,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
code.arg("act_alpha", f"float", "0.0") code.arg("act_alpha", f"float", "0.0")
code.arg("act_beta", f"float", "0.0") code.arg("act_beta", f"float", "0.0")
code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_") code.arg("act_type", f"tv::gemm::Activation", "tv::gemm::Activation::kNone", "cumm.tensorview.gemm.Activation = Activation.None_")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
...@@ -2025,7 +2053,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2025,7 +2053,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
auto_fp32_accum, auto_fp32_accum,
fp32_accum); fp32_accum,
5, // num_run
use_tf32);
tune_res = std::get<0>(tune_res_time); tune_res = std::get<0>(tune_res_time);
}} }}
...@@ -2109,6 +2139,7 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2109,6 +2139,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)") "cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)")
code.arg("auto_fp32_accum", "bool", "true") code.arg("auto_fp32_accum", "bool", "true")
code.arg("fp32_accum", "bool", "false") code.arg("fp32_accum", "bool", "false")
code.arg("use_tf32", "bool", "true")
if CUMM_CPU_ONLY_BUILD: if CUMM_CPU_ONLY_BUILD:
code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")") code.raw(f"TV_THROW_RT_ERR(\"not implemented for cpu!!!\")")
...@@ -2192,7 +2223,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2192,7 +2223,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
auto_fp32_accum, auto_fp32_accum,
fp32_accum); fp32_accum,
5, // num_run
use_tf32);
dgrad_tune_res = std::get<0>(tune_res_time); dgrad_tune_res = std::get<0>(tune_res_time);
}} }}
if (!wgrad_exists){{ if (!wgrad_exists){{
...@@ -2214,7 +2247,9 @@ class ConvGemmOps(pccm.ParameterizedClass): ...@@ -2214,7 +2247,9 @@ class ConvGemmOps(pccm.ParameterizedClass):
1.0, 0.0, 1.0, 0.0,
stream_int, stream_int,
auto_fp32_accum, auto_fp32_accum,
fp32_accum); fp32_accum,
5, // num_run
use_tf32);
wgrad_tune_res = std::get<0>(tune_res_time); wgrad_tune_res = std::get<0>(tune_res_time);
}} }}
int ws_size = conv_tuner.query_workspace_size(wgrad_tune_res.algo_desp, int ws_size = conv_tuner.query_workspace_size(wgrad_tune_res.algo_desp,
......
...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator ...@@ -26,7 +26,7 @@ from spconv.pytorch.core import ThrustSortAllocator
from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul from spconv.pytorch.cppcore import TorchAllocator, torch_tensor_to_tv, get_current_stream, get_arch, TorchSpconvMatmul
from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator from spconv.core_cc.csrc.sparse.alloc import ExternalAllocator
from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE from spconv.constants import SPCONV_CPP_INDICE_PAIRS, SPCONV_CPP_INDICE_PAIRS_IGEMM, SPCONV_CPP_GEMM, SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE, SPCONV_ALLOW_TF32
import spconv.core_cc as _ext import spconv.core_cc as _ext
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
from spconv.core_cc.csrc.sparse.inference import InferenceOps from spconv.core_cc.csrc.sparse.inference import InferenceOps
...@@ -831,7 +831,8 @@ def indice_conv(features: torch.Tensor, ...@@ -831,7 +831,8 @@ def indice_conv(features: torch.Tensor,
FILTER_HWIO, features_tv, filters_tv, FILTER_HWIO, features_tv, filters_tv,
indice_pairs_tv, indice_pair_num_tv, arch, indice_pairs_tv, indice_pair_num_tv, arch,
num_activate_out, inverse, subm, algo.value, num_activate_out, inverse, subm, algo.value,
stream, bias_tv, act_alpha, act_beta, act_type) stream, bias_tv, act_alpha, act_beta, act_type,
use_tf32=SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
return out_features return out_features
if not features.is_cuda: if not features.is_cuda:
...@@ -1011,7 +1012,8 @@ def indice_conv(features: torch.Tensor, ...@@ -1011,7 +1012,8 @@ def indice_conv(features: torch.Tensor,
alpha=1.0, alpha=1.0,
beta=0.0, beta=0.0,
hint=AlgoHint.Fowrard.value, hint=AlgoHint.Fowrard.value,
stream=stream) stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
# CONV.stream_synchronize(stream) # CONV.stream_synchronize(stream)
# t = time.time() # t = time.time()
with timer.record("forward", stream): with timer.record("forward", stream):
...@@ -1103,7 +1105,7 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1103,7 +1105,7 @@ def indice_conv_backward(features: torch.Tensor,
features_tv, filters_tv, out_bp_tv, features_tv, filters_tv, out_bp_tv,
indice_pairs_tv, indice_pair_num_tv, indice_pairs_tv, indice_pair_num_tv,
arch, inverse, subm, algo.value, arch, inverse, subm, algo.value,
stream) stream, use_tf32=SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
df = alloc.allocated[AllocKeys.DFilters] df = alloc.allocated[AllocKeys.DFilters]
return din, df return din, df
...@@ -1270,7 +1272,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1270,7 +1272,8 @@ def indice_conv_backward(features: torch.Tensor,
alpha=1.0, alpha=1.0,
beta=0.0, beta=0.0,
hint=AlgoHint.BackwardInput.value, hint=AlgoHint.BackwardInput.value,
stream=stream) stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
if is_KC_not_CK: if is_KC_not_CK:
a_wgrad = out_bp_tv a_wgrad = out_bp_tv
b_wgrad = features_tv b_wgrad = features_tv
...@@ -1317,7 +1320,8 @@ def indice_conv_backward(features: torch.Tensor, ...@@ -1317,7 +1320,8 @@ def indice_conv_backward(features: torch.Tensor,
alpha=1.0, alpha=1.0,
beta=0.0, beta=0.0,
hint=AlgoHint.BackwardWeight.value, hint=AlgoHint.BackwardWeight.value,
stream=stream) stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
# print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time) # print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time)
# get workspace size for wgrad # get workspace size for wgrad
if is_KC_not_CK: if is_KC_not_CK:
...@@ -1462,7 +1466,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1462,7 +1466,8 @@ def implicit_gemm(features: torch.Tensor,
alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv, alloc, CONV_CPP, features_tv, filters_tv, pair_fwd_tv,
pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv,
num_activate_out, mask_tv, arch, is_train, is_subm, stream, num_activate_out, mask_tv, arch, is_train, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type) timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type,
use_tf32=SPCONV_ALLOW_TF32)
out_features = alloc.allocated[AllocKeys.OutFeatures] out_features = alloc.allocated[AllocKeys.OutFeatures]
mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None)
if is_train: if is_train:
...@@ -1529,7 +1534,8 @@ def implicit_gemm(features: torch.Tensor, ...@@ -1529,7 +1534,8 @@ def implicit_gemm(features: torch.Tensor,
reverse_mask=False, reverse_mask=False,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream, stream=stream,
fp32_accum=fp32_accum) fp32_accum=fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
mask_width = tune_res.algo_desp.tile_shape[0] mask_width = tune_res.algo_desp.tile_shape[0]
if is_train: if is_train:
mask_output_fwd = torch.empty( mask_output_fwd = torch.empty(
...@@ -1741,7 +1747,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1741,7 +1747,8 @@ def implicit_gemm_backward(features: torch.Tensor,
pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv, pair_bwd_tv, pair_mask_fwd_splits_tv, pair_mask_bwd_splits_tv,
mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv, mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv,
mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream, mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream,
timer_cpp, auto_fp32_accum, fp32_accum) timer_cpp, auto_fp32_accum, fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
din = alloc.allocated[AllocKeys.DIn] din = alloc.allocated[AllocKeys.DIn]
dfilters = alloc.allocated[AllocKeys.DFilters] dfilters = alloc.allocated[AllocKeys.DFilters]
return din, dfilters return din, dfilters
...@@ -1817,7 +1824,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1817,7 +1824,8 @@ def implicit_gemm_backward(features: torch.Tensor,
reverse_mask=is_subm, reverse_mask=is_subm,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
stream=stream, stream=stream,
fp32_accum=fp32_accum) fp32_accum=fp32_accum,
use_tf32=SPCONV_ALLOW_TF32)
if wgrad_tune_res is None: if wgrad_tune_res is None:
wgrad_tune_res, _ = CONV.tune_and_cache( wgrad_tune_res, _ = CONV.tune_and_cache(
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
...@@ -1835,7 +1843,8 @@ def implicit_gemm_backward(features: torch.Tensor, ...@@ -1835,7 +1843,8 @@ def implicit_gemm_backward(features: torch.Tensor,
mask_filter=masks[0].item(), mask_filter=masks[0].item(),
mask_output=tv.Tensor(), mask_output=tv.Tensor(),
mask_width=mask_width, mask_width=mask_width,
stream=stream) stream=stream,
use_tf32=SPCONV_ALLOW_TF32)
workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp, workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk, wgrad_tune_res.splitk,
ConvOpType.kBackwardWeight, ConvOpType.kBackwardWeight,
......
...@@ -34,6 +34,7 @@ from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult ...@@ -34,6 +34,7 @@ from spconv.core_cc.csrc.sparse.convops import GemmTuneResult, ConvTuneResult
from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.core import SparseConvTensor
from spconv.test_utils import TestCase from spconv.test_utils import TestCase
from cumm import tensorview as tv from cumm import tensorview as tv
from spconv.constants import SPCONV_ALLOW_TF32
from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
import os import os
from cumm.gemm.codeops import div_up from cumm.gemm.codeops import div_up
...@@ -279,6 +280,8 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -279,6 +280,8 @@ def _test_impgemm_conv_cuda(subm: bool):
shapes = [[19, 18, 17]] shapes = [[19, 18, 17]]
batchsizes = [1] batchsizes = [1]
dtypes = [np.float32, np.float16] dtypes = [np.float32, np.float16]
# dtypes = [np.float16]
# dtypes = [np.int8] # dtypes = [np.int8]
test_case = TestCase() test_case = TestCase()
# in_channels = [32] # in_channels = [32]
...@@ -331,9 +334,11 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -331,9 +334,11 @@ def _test_impgemm_conv_cuda(subm: bool):
if SPCONV_CPP_GEMM: if SPCONV_CPP_GEMM:
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv, avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value, NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False) NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, op_type.value, -1, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else: else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1) avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1,
use_tf32=SPCONV_ALLOW_TF32)
if op_type == ConvOpType.kForward and tester.check_act: if op_type == ConvOpType.kForward and tester.check_act:
act = tv.gemm.Activation.ReLU act = tv.gemm.Activation.ReLU
else: else:
...@@ -535,9 +540,11 @@ def _test_impgemm_conv_cuda(subm: bool): ...@@ -535,9 +540,11 @@ def _test_impgemm_conv_cuda(subm: bool):
avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv, avail_desps = CONV_CPP.get_all_available(inp_tv, weight_tv, output_tv,
NHWC.layout_type.value, NHWC.layout_type.value, NHWC.layout_type.value, NHWC.layout_type.value,
NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch, NHWC.layout_type.value, NHWC.interleave, NHWC.interleave, NHWC.interleave, arch,
ConvOpType.kBackwardWeight.value, mask_width, True, False) ConvOpType.kBackwardWeight.value, mask_width, True, False,
use_tf32=SPCONV_ALLOW_TF32)
else: else:
avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width) avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps: for desp in avail_desps:
weight_tv.zero_() weight_tv.zero_()
if subm: if subm:
...@@ -753,9 +760,11 @@ def _test_native_conv_cuda(subm: bool): ...@@ -753,9 +760,11 @@ def _test_native_conv_cuda(subm: bool):
b = weight_tv.select(1, tester.kv // 2) b = weight_tv.select(1, tester.kv // 2)
c = inp_tv c = inp_tv
if SPCONV_CPP_GEMM: if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value) avail_desps = GEMM_CPP.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC.value,
use_tf32=SPCONV_ALLOW_TF32)
else: else:
avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC) avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps: for desp in avail_desps:
if subm: if subm:
...@@ -827,9 +836,11 @@ def _test_native_conv_cuda(subm: bool): ...@@ -827,9 +836,11 @@ def _test_native_conv_cuda(subm: bool):
b = inp_tv b = inp_tv
c = weight_tv.select(1, tester.kv // 2) c = weight_tv.select(1, tester.kv // 2)
if SPCONV_CPP_GEMM: if SPCONV_CPP_GEMM:
avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value) avail_desps = GEMM_CPP.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB.value,
use_tf32=SPCONV_ALLOW_TF32)
else: else:
avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB) avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB,
use_tf32=SPCONV_ALLOW_TF32)
for desp in avail_desps: for desp in avail_desps:
# print(desp, C, K, k, s, p, d) # print(desp, C, K, k, s, p, d)
...@@ -900,8 +911,8 @@ def test_all_algo_unit(): ...@@ -900,8 +911,8 @@ def test_all_algo_unit():
# for i in range(5): # for i in range(5):
_test_impgemm_conv_cuda(True) _test_impgemm_conv_cuda(True)
_test_impgemm_conv_cuda(False) _test_impgemm_conv_cuda(False)
# _test_native_conv_cuda(True) _test_native_conv_cuda(True)
# _test_native_conv_cuda(False) _test_native_conv_cuda(False)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment