Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
0c07559f
Commit
0c07559f
authored
Jul 28, 2022
by
yan.yan
Browse files
working on performance problem
parent
21bb00ae
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
222 deletions
+93
-222
spconv/algo.py
spconv/algo.py
+1
-1
spconv/constants.py
spconv/constants.py
+3
-3
spconv/core.py
spconv/core.py
+12
-193
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+8
-4
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+24
-4
spconv/pytorch/ops.py
spconv/pytorch/ops.py
+32
-11
test/benchmark.py
test/benchmark.py
+13
-6
No files found.
spconv/algo.py
View file @
0c07559f
...
@@ -752,7 +752,7 @@ class SimpleConv:
...
@@ -752,7 +752,7 @@ class SimpleConv:
use_f32_as_accum
=
weight
.
dim
(
0
)
*
kv
>
128
*
27
use_f32_as_accum
=
weight
.
dim
(
0
)
*
kv
>
128
*
27
else
:
else
:
use_f32_as_accum
=
fp32_accum
use_f32_as_accum
=
fp32_accum
use_f32_as_accum
=
False
#
use_f32_as_accum = False
for
algo
in
avail_algos
:
for
algo
in
avail_algos
:
static_key
=
(
layout_i
.
layout_type
.
value
,
static_key
=
(
layout_i
.
layout_type
.
value
,
layout_w
.
layout_type
.
value
,
layout_w
.
layout_type
.
value
,
...
...
spconv/constants.py
View file @
0c07559f
...
@@ -99,9 +99,9 @@ class AllocKeys:
...
@@ -99,9 +99,9 @@ class AllocKeys:
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_CPP_INDICE_PAIRS
=
Tru
e
SPCONV_CPP_INDICE_PAIRS
=
Fals
e
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
Tru
e
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
Fals
e
SPCONV_CPP_GEMM
=
Tru
e
SPCONV_CPP_GEMM
=
Fals
e
SPCONV_FX_TRACE_MODE
=
os
.
getenv
(
"SPCONV_FX_TRACE_MODE"
,
"0"
)
==
"1"
SPCONV_FX_TRACE_MODE
=
os
.
getenv
(
"SPCONV_FX_TRACE_MODE"
,
"0"
)
==
"1"
\ No newline at end of file
spconv/core.py
View file @
0c07559f
...
@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
...
@@ -16,10 +16,9 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from
cumm.gemm
import
kernel
from
cumm.gemm
import
kernel
from
typing
import
List
from
typing
import
List
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvFwd
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.algocore
import
get_gemm_algo_desp_from_param
from
spconv.constants
import
NDIM_DONT_CARE
from
spconv.constants
import
NDIM_DONT_CARE
...
@@ -41,17 +40,17 @@ class AlgoHint(Enum):
...
@@ -41,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# *gen_shuffle_params(
# *gen_shuffle_params(
...
@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
...
@@ -84,9 +83,6 @@ SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
*
gen_shuffle_params
((
32
,
32
,
32
),
(
32
,
32
,
8
),
[
"f32,f32,f32,f32,f32"
],
*
gen_shuffle_params
((
32
,
32
,
32
),
(
32
,
32
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
*
gen_shuffle_params
((
16
,
32
,
8
),
(
16
,
16
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# fall back kernels if mat is misaligned for half
# fall back kernels if mat is misaligned for half
# TODO use access-per-vector kernel instead of simt kernel for fallback
# TODO use access-per-vector kernel instead of simt kernel for fallback
*
gen_shuffle_params
((
128
,
128
,
8
),
(
32
,
64
,
8
),
[
"f16,f16,f16,f32,f32"
],
*
gen_shuffle_params
((
128
,
128
,
8
),
(
32
,
64
,
8
),
[
"f16,f16,f16,f32,f32"
],
...
@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
...
@@ -169,11 +165,11 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
...
@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
...
@@ -181,15 +177,15 @@ SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))),
]
]
# SHUFFLE_TURING_PARAMS = []
# SHUFFLE_TURING_PARAMS = []
...
@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [
...
@@ -403,8 +399,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
]
]
IMPLGEMM_VOLTA_PARAMS
=
[
IMPLGEMM_VOLTA_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
...
@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -668,181 +662,6 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
]
]
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
...
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
View file @
0c07559f
...
@@ -11,7 +11,7 @@ class ConvGemmOps:
...
@@ -11,7 +11,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor,
arch: Tuple[int, int],
num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
"""
1. this function need to take a out features
1. this function need to take a out features
that from subm first mm.
that from subm first mm.
...
@@ -26,6 +26,7 @@ class ConvGemmOps:
...
@@ -26,6 +26,7 @@ class ConvGemmOps:
filters:
filters:
indice_pairs:
indice_pairs:
indice_pair_num:
indice_pair_num:
arch:
num_activate_out:
num_activate_out:
inverse:
inverse:
subm:
subm:
...
@@ -34,7 +35,7 @@ class ConvGemmOps:
...
@@ -34,7 +35,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor,
arch: Tuple[int, int],
inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -47,6 +48,7 @@ class ConvGemmOps:
...
@@ -47,6 +48,7 @@ class ConvGemmOps:
out_bp:
out_bp:
indice_pairs:
indice_pairs:
indice_pair_num:
indice_pair_num:
arch:
inverse:
inverse:
subm:
subm:
algo:
algo:
...
@@ -54,7 +56,7 @@ class ConvGemmOps:
...
@@ -54,7 +56,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor,
arch: Tuple[int, int],
is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -66,6 +68,7 @@ class ConvGemmOps:
...
@@ -66,6 +68,7 @@ class ConvGemmOps:
mask_argsort_fwd_splits:
mask_argsort_fwd_splits:
num_activate_out:
num_activate_out:
masks:
masks:
arch:
is_train:
is_train:
is_subm:
is_subm:
stream_int:
stream_int:
...
@@ -75,7 +78,7 @@ class ConvGemmOps:
...
@@ -75,7 +78,7 @@ class ConvGemmOps:
"""
"""
...
...
@staticmethod
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor,
arch: Tuple[int, int],
mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
"""
"""
Args:
Args:
allocator:
allocator:
...
@@ -91,6 +94,7 @@ class ConvGemmOps:
...
@@ -91,6 +94,7 @@ class ConvGemmOps:
mask_argsort_bwd_splits:
mask_argsort_bwd_splits:
mask_output_fwd:
mask_output_fwd:
masks:
masks:
arch:
mask_width:
mask_width:
is_subm:
is_subm:
stream_int:
stream_int:
...
...
spconv/csrc/sparse/convops.py
View file @
0c07559f
...
@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1377,6 +1377,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"features, filters, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"features, filters, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
...
@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1489,7 +1491,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
auto a_shape = a.shape();
auto a_shape = a.shape();
auto c_shape = c.shape();
auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
...
@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1584,6 +1586,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"all_w_is_krsc, filter_hwio"
,
"bool"
)
code
.
arg
(
"all_w_is_krsc, filter_hwio"
,
"bool"
)
code
.
arg
(
"features, filters, out_bp, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"features, filters, out_bp, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
...
@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1594,6 +1598,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
std::vector<int64_t> filter_shape_per_kv;
std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector();
auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK;
bool is_KC_not_CK;
if (!all_w_is_krsc){{
if (!all_w_is_krsc){{
kv_dim = 0;
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
is_KC_not_CK = !filter_hwio;
...
@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1700,7 +1705,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
}}
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
...
@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1899,6 +1904,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
"std::vector<tv::Tensor>"
)
"std::vector<tv::Tensor>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
...
@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1926,7 +1933,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
{{num_activate_out, out_channel}}, features.dtype(), features.device(), stream_int);
}}
}}
auto arch = get_compute_capability();
// auto start_ev = tv::CUDAEvent();
// start_ev.record(stream_int);
// auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo(
auto tuned_res_exist = conv_tuner.get_tuned_algo(
...
@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1959,6 +1970,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
fp32_accum);
fp32_accum);
tune_res = std::get<0>(tune_res_time);
tune_res = std::get<0>(tune_res_time);
}}
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
std::vector<tv::Tensor> mask_output_fwd_splits;
...
@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1974,6 +1986,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
mask_output_fwd_splits.push_back(tv::Tensor());
mask_output_fwd_splits.push_back(tv::Tensor());
}}
}}
}}
}}
for (int j = 0; j < num_split; ++j){{
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
conv_tuner.run_with_tuned_result(
...
@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -1995,6 +2008,11 @@ class ConvGemmOps(pccm.ParameterizedClass):
false, // verbose
false, // verbose
timer);
timer);
}}
}}
// auto end_ev = tv::CUDAEvent();
// end_ev.record(stream_int);
// tv::ssprint(tune_res.algo_desp.__repr__(), "WTF", exists,
// features.shape(), filters.shape(), out_features.shape(), tv::CUDAEvent::sync_and_duration(start_ev, end_ev));
return mask_width;
return mask_width;
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
...
@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2013,6 +2031,8 @@ class ConvGemmOps(pccm.ParameterizedClass):
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"is_subm"
,
"bool"
)
code
.
arg
(
"is_subm"
,
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
...
@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
...
@@ -2056,7 +2076,7 @@ class ConvGemmOps(pccm.ParameterizedClass):
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability();
//
auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt,
kBackwardInputInt,
...
...
spconv/pytorch/ops.py
View file @
0c07559f
...
@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm(
...
@@ -419,10 +419,16 @@ def get_indice_pairs_implicit_gemm(
is_mask_split
=
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
is_mask_split
=
algo
==
ConvAlgo
.
MaskSplitImplicitGemm
mask_split_count
=
2
if
is_mask_split
else
1
mask_split_count
=
2
if
is_mask_split
else
1
if
subm
:
if
subm
:
if
is_train
:
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
pair
=
torch
.
full
((
2
,
kv
,
indices
.
shape
[
0
]),
-
1
,
-
1
,
dtype
=
indices
.
dtype
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
device
=
indices
.
device
)
else
:
pair
=
torch
.
full
((
1
,
kv
,
indices
.
shape
[
0
]),
-
1
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
)
else
:
else
:
# for regular conv, pair-in not equal to pair-out
# for regular conv, pair-in not equal to pair-out
pair
=
torch
.
full
((
kv
,
indices
.
shape
[
0
]),
pair
=
torch
.
full
((
kv
,
indices
.
shape
[
0
]),
...
@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm(
...
@@ -476,6 +482,7 @@ def get_indice_pairs_implicit_gemm(
ksize
=
ksize
,
ksize
=
ksize
,
dilation
=
dilation
,
dilation
=
dilation
,
indice_pair_mask
=
pair_mask_tv
,
indice_pair_mask
=
pair_mask_tv
,
backward
=
is_train
,
stream_int
=
stream
)
stream_int
=
stream
)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# print("SUBM0", time.time() - t)
# print("SUBM0", time.time() - t)
...
@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm(
...
@@ -505,10 +512,12 @@ def get_indice_pairs_implicit_gemm(
CONV
.
stream_synchronize
(
stream
)
CONV
.
stream_synchronize
(
stream
)
print
(
"SUBM"
,
time
.
time
()
-
t
)
print
(
"SUBM"
,
time
.
time
()
-
t
)
if
is_train
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
pair
[
1
],
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
return
(
out_inds
,
indice_num_per_loc
,
pair
[
0
],
torch
.
Tensor
(),
pair_mask_in_splits
,
[],
mask_argsort_in_splits
,
[],
masks
)
else
:
else
:
if
DEBUG
:
if
DEBUG
:
...
@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor,
...
@@ -753,11 +762,15 @@ def indice_conv(features: torch.Tensor,
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
stream
=
0
stream
=
0
arch
=
(
0
,
0
)
if
features
.
is_cuda
:
if
features
.
is_cuda
:
# plain get_arch by cuda api is VERY SLOW.
arch
=
get_arch
()
stream
=
get_current_stream
()
stream
=
get_current_stream
()
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
ConvGemmOps
.
indice_conv
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
num_activate_out
,
inverse
,
subm
,
algo
.
value
,
stream
)
stream
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
...
@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor,
...
@@ -996,12 +1009,16 @@ def indice_conv_backward(features: torch.Tensor,
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
indice_pair_num_tv
=
torch_tensor_to_tv
(
indice_pair_num
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
filters_tv
=
torch_tensor_to_tv
(
filters
)
stream
=
0
stream
=
0
arch
=
(
0
,
0
)
if
features
.
is_cuda
:
if
features
.
is_cuda
:
stream
=
get_current_stream
()
stream
=
get_current_stream
()
arch
=
get_arch
()
ConvGemmOps
.
indice_conv_backward
(
alloc
,
ext_mm
,
GEMM_CPP
,
ConvGemmOps
.
indice_conv_backward
(
alloc
,
ext_mm
,
GEMM_CPP
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
ALL_WEIGHT_IS_KRSC
,
FILTER_HWIO
,
features_tv
,
filters_tv
,
out_bp_tv
,
features_tv
,
filters_tv
,
out_bp_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
indice_pairs_tv
,
indice_pair_num_tv
,
arch
,
inverse
,
subm
,
algo
.
value
,
stream
)
inverse
,
subm
,
algo
.
value
,
stream
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
df
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
...
@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1347,10 +1364,12 @@ def implicit_gemm(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
if
fp32_accum
is
None
:
fp32_accum
=
False
fp32_accum
=
False
arch
=
get_arch
()
mask_width
=
ConvGemmOps
.
implicit_gemm
(
mask_width
=
ConvGemmOps
.
implicit_gemm
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
pair_fwd_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
pair_mask_fwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
num_activate_out
,
mask_tv
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
num_activate_out
,
mask_tv
,
arch
,
is_train
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
auto_fp32_accum
,
fp32_accum
)
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
out_features
=
alloc
.
allocated
[
AllocKeys
.
OutFeatures
]
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
mask_output_fwd
=
alloc
.
allocated
.
get
(
AllocKeys
.
MaskOutputFwd
,
None
)
...
@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
...
@@ -1441,7 +1460,7 @@ def implicit_gemm(features: torch.Tensor,
# CONV.stream_synchronize(stream)
# CONV.stream_synchronize(stream)
# t = time.time()
# t = time.time()
#
print(tune_res.algo_desp)
print
(
tune_res
.
algo_desp
,
"REF"
,
features_tv
.
shape
,
filters
.
shape
)
# with tv.measure_and_print("f16 time"):
# with tv.measure_and_print("f16 time"):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
with
timer
.
record
(
"implicit_gemm"
,
stream
):
for
j
in
range
(
num_split
):
for
j
in
range
(
num_split
):
...
@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor,
...
@@ -1613,11 +1632,13 @@ def implicit_gemm_backward(features: torch.Tensor,
auto_fp32_accum
=
fp32_accum
is
None
auto_fp32_accum
=
fp32_accum
is
None
if
fp32_accum
is
None
:
if
fp32_accum
is
None
:
fp32_accum
=
False
fp32_accum
=
False
arch
=
get_arch
()
ConvGemmOps
.
implicit_gemm_backward
(
ConvGemmOps
.
implicit_gemm_backward
(
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
alloc
,
CONV_CPP
,
features_tv
,
filters_tv
,
out_bp_tv
,
pair_fwd_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
pair_bwd_tv
,
pair_mask_fwd_splits_tv
,
pair_mask_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_argsort_fwd_splits_tv
,
mask_argsort_bwd_splits_tv
,
mask_output_fwd_tv
,
mask_tv
,
mask_width
,
is_subm
,
stream
,
mask_output_fwd_tv
,
mask_tv
,
arch
,
mask_width
,
is_subm
,
stream
,
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
timer_cpp
,
auto_fp32_accum
,
fp32_accum
)
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
din
=
alloc
.
allocated
[
AllocKeys
.
DIn
]
dfilters
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
dfilters
=
alloc
.
allocated
[
AllocKeys
.
DFilters
]
...
...
test/benchmark.py
View file @
0c07559f
...
@@ -113,7 +113,7 @@ class Net(nn.Module):
...
@@ -113,7 +113,7 @@ class Net(nn.Module):
# nn.BatchNorm1d(32),
# nn.BatchNorm1d(32),
# nn.ReLU(),
# nn.ReLU(),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
# spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
,
record_voxel_count
=
True
),
spconv
.
SparseMaxPool3d
(
2
,
2
,
algo
=
pool_algo
),
spconv
.
SubMConv3d
(
64
,
spconv
.
SubMConv3d
(
64
,
96
,
96
,
3
,
3
,
...
@@ -312,7 +312,7 @@ def sort_bench():
...
@@ -312,7 +312,7 @@ def sort_bench():
for
i
in
range
(
10
):
for
i
in
range
(
10
):
a_tv_1
=
a_tv
.
clone
()
a_tv_1
=
a_tv
.
clone
()
SpconvOps
.
sort_1d_by_key
(
a_tv_1
[
0
],
mask_argsort_tv
[
0
])
SpconvOps
.
sort_1d_by_key
(
a_tv_1
[
0
],
mask_argsort_tv
[
0
])
import
json
def
main
():
def
main
():
import
pickle
import
pickle
...
@@ -332,7 +332,8 @@ def main():
...
@@ -332,7 +332,8 @@ def main():
voxels_th
=
torch
.
from_numpy
(
voxels
).
to
(
device
).
to
(
dtype
)
voxels_th
=
torch
.
from_numpy
(
voxels
).
to
(
device
).
to
(
dtype
)
coors_th
=
torch
.
from_numpy
(
coors
).
to
(
device
).
int
()
coors_th
=
torch
.
from_numpy
(
coors
).
to
(
device
).
int
()
voxels_th
.
requires_grad
=
True
voxels_th
.
requires_grad
=
True
algo
=
spconv
.
ConvAlgo
.
Native
algo
=
spconv
.
ConvAlgo
.
MaskImplicitGemm
print
(
"ALGO"
)
# 3080 Laptop
# 3080 Laptop
# MaskImpGemm: 11.2ms
# MaskImpGemm: 11.2ms
# MaskSplitImpGemm: 12.2ms
# MaskSplitImpGemm: 12.2ms
...
@@ -355,7 +356,7 @@ def main():
...
@@ -355,7 +356,7 @@ def main():
# MaskImpGemm: 51.0ms
# MaskImpGemm: 51.0ms
# MaskSplitImpGemm: 41.1ms
# MaskSplitImpGemm: 41.1ms
# algo = None
# algo = None
net
=
Net
(
spatial_shape
,
algo
).
to
(
device
).
eval
().
to
(
dtype
).
train
()
net
=
Net
(
spatial_shape
,
algo
).
to
(
device
).
eval
().
to
(
dtype
)
#
.train()
# net.load_state_dict(net.state_dict())
# net.load_state_dict(net.state_dict())
spconv
.
assign_name_for_sparse_modules
(
net
)
spconv
.
assign_name_for_sparse_modules
(
net
)
print
(
coors_th
.
shape
)
print
(
coors_th
.
shape
)
...
@@ -368,13 +369,13 @@ def main():
...
@@ -368,13 +369,13 @@ def main():
print
(
out
.
spatial_shape
,
out
.
features
.
mean
(),
out
.
features
.
max
(),
print
(
out
.
spatial_shape
,
out
.
features
.
mean
(),
out
.
features
.
max
(),
out
.
features
.
min
())
out
.
features
.
min
())
times
=
[]
times
=
[]
show_metrics
=
False
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
i
in
range
(
20
):
for
i
in
range
(
20
):
print
(
"------------"
)
print
(
"------------"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t
=
time
.
time
()
t
=
time
.
time
()
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
False
)
out_nograd
=
net
(
voxels_th
,
coors_th
,
1
,
show_metrics
)
timer
=
out_nograd
.
_timer
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res = timer.collect_by_name("forward", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
# res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
...
@@ -385,6 +386,12 @@ def main():
...
@@ -385,6 +386,12 @@ def main():
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# sort_bench()
# sort_bench()
times
.
append
(
time
.
time
()
-
t
)
times
.
append
(
time
.
time
()
-
t
)
if
show_metrics
:
timer
=
out_nograd
.
_timer
items
=
list
(
timer
.
get_all_pair_time
().
items
())
items
.
sort
(
key
=
lambda
x
:
x
[
0
])
print
(
json
.
dumps
(
dict
(
items
),
indent
=
2
))
# state = net.state_dict()
# state = net.state_dict()
# state.pop("net.2.max_num_voxels_during_training")
# state.pop("net.2.max_num_voxels_during_training")
# net.load_state_dict(state)
# net.load_state_dict(state)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment