Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
0c07559f
Commit
0c07559f
authored
Jul 28, 2022
by
yan.yan
Browse files
working on performance problem
parent
21bb00ae
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
222 deletions
+93
-222
spconv/algo.py
spconv/algo.py
+1
-1
spconv/constants.py
spconv/constants.py
+3
-3
spconv/core.py
spconv/core.py
+12
-193
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+8
-4
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+24
-4
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+32
-11
test/benchmark.py
test/benchmark.py
+13
-6
No files found.
spconv/algo.py
View file @
0c07559f
...
...
@@ -752,7 +752,7 @@ class SimpleConv:
use_f32_as_accum
=
weight
.
dim
(
0
)
*
kv
>
128
*
27
else
:
use_f32_as_accum
=
fp32_accum
use_f32_as_accum
=
False
#
use_f32_as_accum = False
for
algo
in
avail_algos
:
static_key
=
(
layout_i
.
layout_type
.
value
,
layout_w
.
layout_type
.
value
,
...
...
spconv/constants.py
View file @
0c07559f
...
...
@@ -99,9 +99,9 @@ class AllocKeys:
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_CPP_INDICE_PAIRS
=
Tru
e
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
Tru
e
SPCONV_CPP_INDICE_PAIRS
=
Fals
e
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
Fals
e
SPCONV_CPP_GEMM
=
Tru
e
SPCONV_CPP_GEMM
=
Fals
e
SPCONV_FX_TRACE_MODE
=
os
.
getenv
(
"SPCONV_FX_TRACE_MODE"
,
"0"
)
==
"1"
\ No newline at end of file
spconv/core.py
View file @
0c07559f
...
...
@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from
cumm.gemm
import
kernel
from
typing
import
List
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvFwd
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.algocore
import
get_gemm_algo_desp_from_param
from
spconv.constants
import
NDIM_DONT_CARE
...
...
@@ -41,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# *gen_shuffle_params(
...
...
@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
*
gen_shuffle_params
((
32
,
32
,
32
),
(
32
,
32
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
*
gen_shuffle_params
((
16
,
32
,
8
),
(
16
,
16
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback
*
gen_shuffle_params
((
128
,
128
,
8
),
(
32
,
64
,
8
),
[
"f16,f16,f16,f32,f32"
],
...
...
@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
# *gen_shuffle_params(
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
...
...
@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
]
# SHUFFLE_TURING_PARAMS = []
...
...
@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_VOLTA_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
...
...
@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
]
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
0c07559f
...
...
@@ -11,7 +11,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor,
arch: Tuple[int, int],
num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
1. this function need to take a out features
that from subm first mm.
...
...
@@ -26,6 +26,7 @@ class ConvGemmOps:
filters:
indice_pairs:
indice_pair_num:
arch:
num_activate_out:
inverse:
subm:
...
...
@@ -34,7 +35,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor,
arch: Tuple[int, int],
inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
Args:
allocator:
...
...
@@ -47,6 +48,7 @@ class ConvGemmOps:
out_bp:
indice_pairs:
indice_pair_num:
arch:
inverse:
subm:
algo:
...
...
@@ -54,7 +56,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor,
arch: Tuple[int, int],
is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
"""
Args:
allocator:
...
...
@@ -66,6 +68,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
num_activate_out:
masks:
arch:
is_train:
is_subm:
stream_int:
...
...
@@ -75,7 +78,7 @@ class ConvGemmOps:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor,
arch: Tuple[int, int],
mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
"""
Args:
allocator:
...
...
@@ -91,6 +94,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
arch:
mask_width:
is_subm:
stream_int:
...
...
spconv/csrc/sparse/convops.py
View file @
0c07559f
...
...
@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"features, filters, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
...
...
@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
auto a_shape = a.shape();
auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
...
...
@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"all_w_is_krsc, filter_hwio"
,
"bool"
)
code
.
arg
(
"features, filters, out_bp, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
...
...
@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
...
...
@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
...
...
@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
...
...
@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}}
auto arch = get_compute_capability();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo(
...
...
@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
fp32_accum);
tune_res = std::get<0>(tune_res_time);
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
...
...
@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_output_fwd_splits.push_back(tv::Tensor());
}}
}}
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
...
...
@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // verbose
timer);
}}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width;
"""
)
return
code
.
ret
(
"int"
)
...
...
@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"is_subm"
,
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
...
...
@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt,
...
...
spconv/pytorch/ops.py
View file @
0c07559f
...
...
@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm(
is_mask_split
=
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
mask_split_count
=
2
if
is_mask_split
else
1
if
subm
:
if
is_train
:
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
-
1
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
else
:
pair
=
torch
.
full
((
1
,
kv
,
indices
.
shape
[
0
]),
-
1
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
else
:
# for regular conv, pair-in not equal to pair-out
pair
=
torch
.
full
((
kv
,
indices
.
shape
[
0
]),
...
...
@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm(
ksize
=
ksize
,
dilation
=
dilation
,
indice_pair_mask
=
pair_mask_tv
,
backward
=
is_train
,
stream_int
=
stream
)
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
...
...
@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm(
CONV
.
stream_synchronize
(
stream
)
print
(
"SUBM"
,
time
.
time
()
-
t
)
if
is_train
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
torch
.
Tensor
(),
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
if
DEBUG
:
...
...
@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor,
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
stream
=
0
arch
=
(
0
,
0
)
if
features
.
is_cuda
:
# plain get_arch by cuda api is VERY SLOW.
arch
=
get_arch
()
stream
=
get_current_stream
()
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
...
...
@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
stream
=
0
arch
=
(
0
,
0
)
if
features
.
is_cuda
:
stream
=
get_current_stream
()
arch
=
get_arch
()
ConvGemmOps
.
indice_conv_backward
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
out_bp_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
inverse
,
subm
,
algo
.
value
,
stream
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
...
...
@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
fp32_accum
=
False
arch
=
get_arch
()
mask_width
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
...
...
@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# t = time.time()
#
print(tune_res.algo_desp)
print
(
tune_res
.
algo_desp
,
"REF"
,
features_tv
.
shape
,
filters
.
shape
)
# with tv.measure_and_print("f16 time"):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
for
j
in
range
(
num_split
):
...
...
@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
fp32_accum
=
False
arch
=
get_arch
()
ConvGemmOps
.
implicit_gemm_backward
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_output_fwd_tv
,
mask_tv
,
mask_width
,
is_subm
,
stream
,
mask_output_fwd_tv
,
mask_tv
,
arch
,
mask_width
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
dfilters
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
...
...
test/benchmark.py
View file @
0c07559f
...
...
@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
,
record_voxel_count
=
True
),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
),
spconv
.
SubMConv3d
(
64
,
96
,
3
,
...
...
@@ -312,7 +312,7 @@ def sort_bench():
for
i
in
range
(
10
):
a_tv_1
=
a_tv
.
clone
()
SpconvOps
.
sort_1d_by_key
(
a_tv_1
[
0
],
mask_argsort_tv
[
0
])
import
json
def
main
():
import
pickle
...
...
@@ -332,7 +332,8 @@ def main():
voxels_th
=
torch
.
from_numpy
(
voxels
).
to
(
device
).
to
(
dtype
)
coors_th
=
torch
.
from_numpy
(
coors
).
to
(
device
).
int
()
voxels_th
.
requires_grad
=
True
algo
=
spconv
.
ConvAlgo
.
Native
algo
=
spconv
.
ConvAlgo
.
MaskImplicitGemm
print
(
"ALGO"
)
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
...
...
@@ -355,7 +356,7 @@ def main():
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# algo = None
net
=
Net
(
spatial_shape
,
algo
).
to
(
device
).
eval
().
to
(
dtype
).
train
()
net
=
Net
(
spatial_shape
,
algo
).
to
(
device
).
eval
().
to
(
dtype
)
#
.train()
# net.load_state_dict(net.state_dict())
spconv
.
assign_name_for_sparse_modules
(
net
)
print
(
coors_th
.
shape
)
...
...
@@ -368,13 +369,13 @@ def main():
print
(
out
.
spatial_shape
,
out
.
features
.
mean
(),
out
.
features
.
max
(),
out
.
features
.
min
())
times
=
[]
show_metrics
=
False
with
torch
.
no_grad
():
for
i
in
range
(
20
):
print
(
"------------"
)
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
False
)
timer
=
out_nograd
.
_timer
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
show_metrics
)
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
...
...
@@ -385,6 +386,12 @@ def main():
torch
.
cuda
.
synchronize
()
# sort_bench()
times
.
append
(
time
.
time
()
-
t
)
if
show_metrics
:
timer
=
out_nograd
.
_timer
items
=
list
(
timer
.
get_all_pair_time
().
items
())
items
.
sort
(
key
=
lambda
x
:
x
[
0
])
print
(
json
.
dumps
(
dict
(
items
),
indent
=
2
))
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment