Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
899008fa
Commit
899008fa
authored
Jul 20, 2022
by
yan.yan
Browse files
working on c++ only
parent
f78575ea
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3404 additions
and
327 deletions
+3404
-327
docs/DEVELOPMENT.md
docs/DEVELOPMENT.md
+0
-25
setup.py
setup.py
+22
-3
spconv/algo.py
spconv/algo.py
+187
-67
spconv/algocore.py
spconv/algocore.py
+13
-3
spconv/benchmark/me.py
spconv/benchmark/me.py
+0
-24
spconv/benchmark/thsp.py
spconv/benchmark/thsp.py
+0
-24
spconv/build.py
spconv/build.py
+35
-3
spconv/constants.py
spconv/constants.py
+30
-1
spconv/core.py
spconv/core.py
+177
-27
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+43
-10
spconv/core_cc/csrc/sparse/alloc.pyi
spconv/core_cc/csrc/sparse/alloc.pyi
+18
-4
spconv/core_cc/csrc/sparse/convops/__init__.pyi
spconv/core_cc/csrc/sparse/convops/__init__.pyi
+96
-0
spconv/core_cc/csrc/sparse/convops/convops.pyi
spconv/core_cc/csrc/sparse/convops/convops.pyi
+126
-0
spconv/core_cc/csrc/sparse/convops/gemmops.pyi
spconv/core_cc/csrc/sparse/convops/gemmops.pyi
+107
-0
spconv/core_cc/csrc/sparse/convops/spops.pyi
spconv/core_cc/csrc/sparse/convops/spops.pyi
+101
-0
spconv/core_cc/cumm/common.pyi
spconv/core_cc/cumm/common.pyi
+7
-0
spconv/core_cc/cumm/gemm/main.pyi
spconv/core_cc/cumm/gemm/main.pyi
+3
-2
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+217
-79
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+190
-7
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+2032
-48
No files found.
docs/DEVELOPMENT.md
deleted
100644 → 0
View file @
f78575ea
<!--
Copyright 2021 Yan Yan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# How to develop spconv 2.x
## First step
spconv 2.x is written in a unique c++ framework
```pccm```
. read
[
pccm guide
](
)
to learn how to use
```pccm```
.
It's recommend to uninstall spconv and cumm installed by pip, then install spconv and cumm both in editable mode (
```pip install -e .```
)
## Architecture
\ No newline at end of file
setup.py
View file @
899008fa
...
...
@@ -159,6 +159,9 @@ if disable_jit is not None and disable_jit == "1":
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.hash.core
import
HashTable
from
cumm.common
import
CompileInfo
from
spconv.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.csrc.sparse.convops
import
GemmTunerSimple
,
ExternalSpconvMatmul
from
spconv.csrc.sparse.convops
import
ConvTunerSimple
,
ConvGemmOps
cu
=
GemmMainUnitTest
(
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
)
convcu
=
ConvMainUnitTest
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
)
...
...
@@ -172,14 +175,30 @@ if disable_jit is not None and disable_jit == "1":
std
=
"c++14"
else
:
std
=
"c++17"
cus
=
[
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
()]
if
CUMM_CPU_ONLY_BUILD
:
cus
=
[
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
()]
gemmtuner
=
GemmTunerSimple
(
cu
)
gemmtuner
.
namespace
=
"csrc.sparse.convops.gemmops"
convtuner
=
ConvTunerSimple
(
convcu
)
convtuner
.
namespace
=
"csrc.sparse.convops.convops"
convops
=
ConvGemmOps
(
gemmtuner
,
convtuner
)
convops
.
namespace
=
"csrc.sparse.convops.spops"
else
:
gemmtuner
=
GemmTunerSimple
(
None
)
gemmtuner
.
namespace
=
"csrc.sparse.convops.gemmops"
convtuner
=
ConvTunerSimple
(
None
)
convtuner
.
namespace
=
"csrc.sparse.convops.convops"
convops
=
ConvGemmOps
(
gemmtuner
,
convtuner
)
convops
.
namespace
=
"csrc.sparse.convops.spops"
cus
=
[
gemmtuner
,
convtuner
,
convops
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
(),
ExternalAllocator
(),
ExternalSpconvMatmul
()]
if
not
CUMM_CPU_ONLY_BUILD
:
cus
.
extend
([
cu
,
convcu
])
ext_modules
:
List
[
Extension
]
=
[
PCCMExtension
(
cus
,
"spconv/core_cc"
,
Path
(
__file__
).
resolve
().
parent
/
"spconv"
,
objects_folder
=
"objects"
,
std
=
std
,
disable_pch
=
True
,
verbose
=
True
)
...
...
spconv/algo.py
View file @
899008fa
...
...
@@ -37,7 +37,7 @@ from cumm import dtypes
from
spconv.constants
import
(
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
,
SPCONV_NVRTC_MODE
,
SPCONV_DEBUG_NVRTC_KERNELS
)
from
spconv.core
import
ALL_IMPGEMM_PARAMS
,
AlgoHint
,
ConvAlgo
from
spconv.core
import
ALL_IMPGEMM_PARAMS
,
AlgoHint
,
ConvAlgo
,
ALL_NATIVE_PARAMS
from
spconv.core_cc.cumm.conv.main
import
ConvMainUnitTest
from
spconv.core_cc.cumm.gemm.main
import
GemmMainUnitTest
from
spconv.cppconstants
import
COMPILED_CUDA_ARCHS
...
...
@@ -49,14 +49,17 @@ from spconv import algocore
from
cumm.conv.main
import
gen_gemm_kernels
as
gen_conv_kernels
from
cumm.gemm.main
import
gen_gemm_kernels
from
spconv.core_cc.csrc.sparse.convops
import
GemmTuneResult
,
ConvTuneResult
from
spconv.core_cc.csrc.sparse.convops.gemmops
import
GemmTunerSimple
as
GemmTunerSimpleBase
from
spconv.core_cc.csrc.sparse.convops.convops
import
ConvTunerSimple
as
ConvTunerSimpleBase
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
ALL_CONV_ALGO_DESPS
=
ConvMainUnitTest
.
get_all_conv_algo_desp
()
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
int
,
str
]
class
SimpleGemmAlgoMeta
:
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
tile_ks
:
List
[
int
],
tile_shape_to_algos
:
Dict
[
int
,
List
[
int
]])
->
None
:
...
...
@@ -67,19 +70,29 @@ class SimpleGemmAlgoMeta:
class
BestAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
arch
=
arch
class
BestConvAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
arch
=
arch
def
_get_nvrtc_params
(
mod
:
CummNVRTCModule
,
ker
:
Union
[
GemmKernel
,
ConvKernel
],
kernel_name
:
str
):
def
_get_nvrtc_params
(
mod
:
CummNVRTCModule
,
ker
:
Union
[
GemmKernel
,
ConvKernel
],
kernel_name
:
str
):
nvrtc_mode
=
SPCONV_NVRTC_MODE
nvrtc_params
=
tv
.
gemm
.
NVRTCParams
()
nvrtc_params
.
cumodule
=
mod
.
get_cpp_object
()
...
...
@@ -89,8 +102,7 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
ns
=
ker
.
namespace
if
nvrtc_mode
==
NVRTCMode
.
DynamicParallism
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::nvrtc_kernel"
)
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::nvrtc_kernel"
)
elif
nvrtc_mode
==
NVRTCMode
.
KernelAndCPU
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::
{
kernel_name
}
"
)
...
...
@@ -101,8 +113,10 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
nvrtc_params
.
param_storage
=
tv
.
empty
([
nvrtc_params
.
param_size
],
tv
.
uint8
,
0
)
nvrtc_params
.
param_storage_cpu
=
tv
.
empty
(
[
nvrtc_params
.
param_size
],
tv
.
uint8
,
-
1
,
pinned
=
True
)
nvrtc_params
.
param_storage_cpu
=
tv
.
empty
([
nvrtc_params
.
param_size
],
tv
.
uint8
,
-
1
,
pinned
=
True
)
elif
nvrtc_mode
==
NVRTCMode
.
Direct
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::
{
kernel_name
}
"
)
...
...
@@ -120,9 +134,84 @@ def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
raise
NotImplementedError
return
nvrtc_params
class
GemmTunerSimple
(
GemmTunerSimpleBase
):
def
__init__
(
self
,
desps
:
List
[
GemmAlgoDesp
])
->
None
:
super
().
__init__
(
desps
)
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
],
int
],
NVRTCParams
]
=
{}
def
_compile_nvrtc_module
(
self
,
desp
:
GemmAlgoDesp
):
params
=
algocore
.
get_gemm_param_from_desp
(
desp
)
kernel
=
gen_gemm_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
assert
cudadevrt_p
is
not
None
,
"DynamicParallism must have cudadevrt"
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
def
cached_get_nvrtc_params
(
self
,
desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
stream_int
:
int
)
->
NVRTCParams
:
key
=
(
str
(
desp
),
arch
,
stream_int
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
mod
,
ker
=
self
.
_compile_nvrtc_module
(
desp
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"gemm_kernel"
)
self
.
_nvrtc_caches
[
key
]
=
nvrtc_params
return
nvrtc_params
class
ConvTunerSimple
(
ConvTunerSimpleBase
):
def
__init__
(
self
,
desps
:
List
[
ConvAlgoDesp
])
->
None
:
super
().
__init__
(
desps
)
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
],
int
],
NVRTCParams
]
=
{}
def
_compile_nvrtc_module
(
self
,
desp
:
ConvAlgoDesp
):
params
=
algocore
.
get_conv_param_from_desp
(
desp
)
kernel
=
gen_conv_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
assert
cudadevrt_p
is
not
None
,
"DynamicParallism must have cudadevrt"
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
def
cached_get_nvrtc_params
(
self
,
desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
stream_int
:
int
)
->
NVRTCParams
:
key
=
(
str
(
desp
),
arch
,
stream_int
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
mod
,
ker
=
self
.
_compile_nvrtc_module
(
desp
)
print
(
f
"Can't find algo
{
desp
}
in prebuilt. compile with nvrtc..."
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"conv_kernel"
)
self
.
_nvrtc_caches
[
key
]
=
nvrtc_params
return
nvrtc_params
class
SimpleGemm
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
GemmAlgoDesp
])
->
None
:
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
all_desps
=
[
algocore
.
get_gemm_algo_desp_from_param
(
p
)
for
p
in
ALL_NATIVE_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
if
SPCONV_DEBUG_NVRTC_KERNELS
:
...
...
@@ -178,7 +267,9 @@ class SimpleGemm:
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
...
...
@@ -186,12 +277,12 @@ class SimpleGemm:
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
def
_cached_get_nvrtc_params
(
self
,
desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
def
_cached_get_nvrtc_params
(
self
,
desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
key
=
(
str
(
desp
),
arch
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
...
...
@@ -218,12 +309,15 @@ class SimpleGemm:
trans_c
=
False
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
GemmAlgoDesp
]
=
[]
# print(self.static_key_to_desps)
for
algo
in
avail_algos
:
static_key
=
(
trans_a
,
trans_b
,
trans_c
,
a
.
dtype
,
b
.
dtype
,
c
.
dtype
,
shuffle_type
.
value
,
algo
)
# print(static_key)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
if
desps
is
None
or
len
(
desps
)
==
0
:
continue
# print(desps)
for
desp
in
desps
:
# skip volta tensor op since it is very slow in architectures except volta.
if
arch
>=
(
7
,
5
)
and
desp
.
algo
==
GemmAlgo
.
Volta
.
value
:
...
...
@@ -430,6 +524,7 @@ class SimpleGemm:
best_scatter_params
=
(
-
1
,
-
1
,
-
1
,
-
1
)
all_profile_res
:
List
[
BestAlgoByProfile
]
=
[]
# print(avail)
for
desp
in
avail
:
c_
.
zero_whole_storage_
()
split_k_slices
=
1
...
...
@@ -466,7 +561,8 @@ class SimpleGemm:
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
min_time
=
1000
min_idx
=
-
1
...
...
@@ -490,8 +586,7 @@ class SimpleGemm:
return
res
,
min_time
def
run_with_tuned_result
(
self
,
def
run_with_tuned_result
(
self
,
profile_res
:
BestAlgoByProfile
,
a
:
tv
.
Tensor
,
b
:
tv
.
Tensor
,
...
...
@@ -501,7 +596,7 @@ class SimpleGemm:
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
stream
:
int
,
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
shuffle_type
:
ShuffleStrideType
,
a_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
b_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
c_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
...
...
@@ -510,7 +605,8 @@ class SimpleGemm:
beta
:
float
=
0.0
,
gather_data
:
tv
.
Tensor
=
tv
.
Tensor
(),
workspace
:
tv
.
Tensor
=
tv
.
Tensor
(),
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
force_nvrtc
:
bool
=
False
):
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a
.
shape
,
b
.
shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
...
...
@@ -526,8 +622,10 @@ class SimpleGemm:
if
profile_res
.
splitk
>
1
:
split_k_slices
=
profile_res
.
splitk
params
=
GemmParams
()
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
is_not_static
=
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
if
algo_desp
.
is_nvrtc
and
(
is_not_static
or
force_nvrtc
):
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
a
=
a
params
.
b
=
b
...
...
@@ -569,8 +667,12 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class
SimpleConv
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
ConvAlgoDesp
])
->
None
:
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
.
clear
()
...
...
@@ -650,6 +752,7 @@ class SimpleConv:
use_f32_as_accum
=
weight
.
dim
(
0
)
*
kv
>
128
*
27
else
:
use_f32_as_accum
=
fp32_accum
use_f32_as_accum
=
False
for
algo
in
avail_algos
:
static_key
=
(
layout_i
.
layout_type
.
value
,
layout_w
.
layout_type
.
value
,
...
...
@@ -664,7 +767,6 @@ class SimpleConv:
if
arch
>=
(
7
,
5
)
and
desp
.
algo
==
GemmAlgo
.
Volta
.
value
:
continue
if
arch
>=
(
7
,
0
)
and
is_fp16
:
# skip simt fp16 kernels if we have tensor core
if
desp
.
algo
==
GemmAlgo
.
Simt
:
continue
if
use_f32_as_accum
:
...
...
@@ -675,6 +777,7 @@ class SimpleConv:
ldw
=
weight
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
mask_width_valid
=
True
if
desp
.
op_type
==
ConvOpType
.
kBackwardWeight
.
value
:
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
...
...
@@ -722,7 +825,9 @@ class SimpleConv:
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
...
...
@@ -735,10 +840,12 @@ class SimpleConv:
mod
.
load
()
return
mod
,
kernel
def
_cached_get_nvrtc_params
(
self
,
desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
def
_cached_get_nvrtc_params
(
self
,
desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
key
=
(
str
(
desp
),
arch
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
print
(
f
"Can't find algo
{
desp
}
in prebuilt. compile with nvrtc..."
)
mod
,
ker
=
self
.
_compile_nvrtc_module
(
desp
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"conv_kernel"
)
self
.
_nvrtc_caches
[
key
]
=
nvrtc_params
...
...
@@ -795,8 +902,8 @@ class SimpleConv:
params
.
indices
=
indices
params
.
mask
=
mask
params
.
mask_output
=
mask_output
if
op_type
==
ConvOpType
.
kBackwardWeight
:
assert
not
mask_output
.
empty
()
#
if op_type == ConvOpType.kBackwardWeight:
#
assert not mask_output.empty()
if
op_type
==
ConvOpType
.
kBackwardInput
:
params
.
reverse_mask
=
reverse_mask
params
.
mask_filter
=
mask_filter
...
...
@@ -808,20 +915,20 @@ class SimpleConv:
spk_speeds
=
[]
for
spk
in
splitk_tests
:
this_times
=
[]
for
j
in
range
(
3
):
GemmMainUnitTest
.
stream_synchronize
(
stream
)
t
=
time
.
time
()
for
j
in
range
(
4
):
params
.
split_k_slices
=
spk
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
with
tv
.
measure_duration
(
stream
=
stream
)
as
measure
:
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
tv
.
gemm
.
run_nvrtc_conv_kernel
(
params
)
else
:
ConvMainUnitTest
.
implicit_gemm2
(
params
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
this_times
.
append
(
time
.
time
()
-
t
)
this_times
.
append
(
measure
.
duration
)
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
if
not
all_profile_res
:
raise
ValueError
(
"can't find suitable algorithm for"
,
op_type
)
min_time
=
1000
...
...
@@ -865,7 +972,8 @@ class SimpleConv:
stream
:
int
=
0
,
workspace
:
tv
.
Tensor
=
tv
.
Tensor
(),
verbose
:
bool
=
False
,
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
)):
timer
:
CUDAKernelTimer
=
CUDAKernelTimer
(
False
),
force_nvrtc
:
bool
=
False
):
channel_k
=
output
.
dim
(
1
)
channel_c
=
inp
.
dim
(
1
)
# GemmMainUnitTest.stream_synchronize(stream)
...
...
@@ -879,13 +987,17 @@ class SimpleConv:
else
:
op_type_value
=
op_type
.
value
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type_value
))
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
is_not_static
=
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
if
algo_desp
.
is_nvrtc
and
(
is_not_static
or
force_nvrtc
):
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
input
=
inp
params
.
verbose
=
verbose
params
.
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
params
.
output
=
output
params
.
split_k_slices
=
split_k_slices
params
.
alpha
=
alpha
params
.
beta
=
beta
...
...
@@ -893,6 +1005,7 @@ class SimpleConv:
params
.
mask_argsort
=
mask_argsort
params
.
indices
=
indices
params
.
mask
=
mask
params
.
mask_filter
=
mask_filter
params
.
mask_width
=
mask_width
params
.
mask_filter
=
mask_filter
...
...
@@ -919,6 +1032,13 @@ class SimpleConv:
GEMM
=
SimpleGemm
(
ALL_ALGO_DESPS
)
CONV
=
SimpleConv
(
ALL_CONV_ALGO_DESPS
)
GEMM_CPP
=
GemmTunerSimple
([
algocore
.
get_gemm_algo_desp_from_param
(
p
)
for
p
in
ALL_NATIVE_PARAMS
])
CONV_CPP
=
ConvTunerSimple
([
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
])
if
__name__
==
"__main__"
:
print
(
len
(
ALL_CONV_ALGO_DESPS
))
print
(
ALL_CONV_ALGO_DESPS
[
0
])
spconv/algocore.py
View file @
899008fa
...
...
@@ -24,8 +24,8 @@ from cumm.tensorview.gemm import ConvLayoutType as ConvLayoutTypeCpp
from
cumm.tensorview.gemm
import
ShuffleStrideType
as
ShuffleStrideTypeCpp
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm.gemm.main
import
GemmAlgoParams
from
cumm.conv.main
import
ConvAlgoParams
,
ConvIterAlgo
from
cumm.gemm.main
import
GemmAlgoParams
,
gen_gemm_kernels
from
cumm.conv.main
import
ConvAlgoParams
,
ConvIterAlgo
,
gen_gemm_kernels
as
gen_conv_kernels
from
cumm
import
dtypes
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
...
...
@@ -56,10 +56,15 @@ def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
desp
.
access_per_vector
=
p
.
access_per_vector
desp
.
is_nvrtc
=
p
.
is_nvrtc
def
get_gemm_algo_desp_from_param
(
p
:
GemmAlgoParams
):
desp
=
GemmAlgoDesp
()
_assign_gemm_desp_props
(
desp
,
p
)
# here we must generate kernel for element-per-access data
ker
=
gen_gemm_kernels
(
p
)
desp
.
element_per_access_a
=
ker
.
input_spec
.
input_iter_a
.
element_per_acc
desp
.
element_per_access_b
=
ker
.
input_spec
.
input_iter_b
.
element_per_acc
desp
.
element_per_access_c
=
ker
.
output_spec
.
out_iter
.
element_per_acc
return
desp
...
...
@@ -78,6 +83,10 @@ def get_conv_algo_desp_from_param(p: ConvAlgoParams):
desp
.
interleave_o
=
p
.
layout_desp_output
.
interleave
desp
.
mask_sparse
=
p
.
mask_sparse
desp
.
increment_k_first
=
p
.
increment_k_first
ker
=
gen_conv_kernels
(
p
)
desp
.
element_per_access_a
=
ker
.
input_spec
.
input_iter_a
.
element_per_acc
desp
.
element_per_access_b
=
ker
.
input_spec
.
input_iter_b
.
element_per_acc
desp
.
element_per_access_c
=
ker
.
output_spec
.
out_iter
.
element_per_acc
return
desp
...
...
@@ -106,6 +115,7 @@ def _assign_gemm_params(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
p
.
is_nvrtc
=
desp
.
is_nvrtc
def
get_gemm_param_from_desp
(
desp
:
GemmAlgoDesp
):
p
=
GemmAlgoParams
((
0
,
0
,
0
),
(
0
,
0
,
0
),
0
,
"s8,s8,s8,s8,s8"
,
False
,
False
,
False
,
GemmAlgo
.
Simt
)
...
...
spconv/benchmark/me.py
deleted
100644 → 0
View file @
f78575ea
"""Benchmark MinkowskiEngine
"""
from
spconv.benchmark.core
import
get_voxel_data
import
time
from
pathlib
import
Path
import
numpy
as
np
import
torch
from
torch
import
nn
from
spconv.core
import
ConvAlgo
from
cumm
import
dtypes
from
spconv.test_utils
import
params_grid
_DTYPE_TO_TORCH_DTYPE
=
{
dtypes
.
float32
:
torch
.
float32
,
dtypes
.
float16
:
torch
.
float16
,
}
def
bench_me_basic
(
dtype_str
:
str
):
dtype
=
dtypes
.
get_dtype_by_shortcut
(
dtype_str
)
if
dtype
not
in
_DTYPE_TO_TORCH_DTYPE
:
raise
NotImplementedError
(
"only support bench f32 and f16 for now"
)
torch_dtype
=
_DTYPE_TO_TORCH_DTYPE
[
dtype
]
spconv/benchmark/thsp.py
deleted
100644 → 0
View file @
f78575ea
"""Benchmark torchsparse
"""
from
spconv.benchmark.core
import
get_voxel_data
import
time
from
pathlib
import
Path
import
numpy
as
np
import
torch
from
torch
import
nn
from
spconv.core
import
ConvAlgo
from
cumm
import
dtypes
from
spconv.test_utils
import
params_grid
_DTYPE_TO_TORCH_DTYPE
=
{
dtypes
.
float32
:
torch
.
float32
,
dtypes
.
float16
:
torch
.
float16
,
}
def
bench_torchsparse_basic
(
dtype_str
:
str
):
dtype
=
dtypes
.
get_dtype_by_shortcut
(
dtype_str
)
if
dtype
not
in
_DTYPE_TO_TORCH_DTYPE
:
raise
NotImplementedError
(
"only support bench f32 and f16 for now"
)
torch_dtype
=
_DTYPE_TO_TORCH_DTYPE
[
dtype
]
spconv/build.py
View file @
899008fa
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
pathlib
import
Path
from
typing
import
List
import
pccm
from
pccm.utils
import
project_is_editable
,
project_is_installed
...
...
@@ -32,6 +33,10 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from
spconv.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.hash.core
import
HashTable
from
spconv.csrc.sparse.convops
import
GemmTunerSimple
,
ExternalSpconvMatmul
from
spconv.csrc.sparse.convops
import
ConvTunerSimple
,
ConvGemmOps
from
spconv.csrc.sparse.convops
import
SimpleExternalSpconvMatmul
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
all_shuffle
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_shuffle
))
cu
=
GemmMainUnitTest
(
all_shuffle
)
...
...
@@ -41,8 +46,35 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
.
namespace
=
"cumm.conv.main"
pccm
.
builder
.
build_pybind
([
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
(),
ExternalAllocator
()],
gemmtuner
=
GemmTunerSimple
(
cu
)
gemmtuner
.
namespace
=
"csrc.sparse.convops.gemmops"
convtuner
=
ConvTunerSimple
(
convcu
)
convtuner
.
namespace
=
"csrc.sparse.convops.convops"
convops
=
ConvGemmOps
(
gemmtuner
,
convtuner
)
convops
.
namespace
=
"csrc.sparse.convops.spops"
cus
=
[
cu
,
convcu
,
gemmtuner
,
convtuner
,
convops
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
(),
ExternalAllocator
(),
ExternalSpconvMatmul
(),
SimpleExternalSpconvMatmul
(),
]
pccm
.
builder
.
build_pybind
(
cus
,
PACKAGE_ROOT
/
"core_cc"
,
namespace_root
=
PACKAGE_ROOT
,
load_library
=
False
)
load_library
=
False
,
verbose
=
True
)
# cus_dev: List[pccm.Class] = [
# ]
# pccm.builder.build_pybind(cus_dev,
# PACKAGE_ROOT / "core_cc_dev",
# namespace_root=PACKAGE_ROOT,
# load_library=False,
# verbose=True)
spconv/constants.py
View file @
899008fa
...
...
@@ -30,6 +30,7 @@ if _filter_hwio_env is not None:
raise
NotImplementedError
(
"SPCONV_FILTER_HWIO is deprecated. use SPCONV_SAVED_WEIGHT_LAYOUT instead."
)
DISABLE_JIT
=
os
.
getenv
(
"SPCONV_DISABLE_JIT"
,
"0"
)
==
"1"
NDIM_DONT_CARE
=
3
FILTER_HWIO
=
False
...
...
@@ -59,8 +60,10 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE
=
NVRTCMode
.
ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS
=
False
SPCONV_DEBUG_CPP_ONLY
=
project_is_editable
(
PACKAGE_NAME
)
class
Spconv
Alloc
ator
Keys
:
class
AllocKeys
:
Pair
=
"Pair"
IndiceNumPerLoc
=
"IndiceNumPerLoc"
PairMask
=
"PairMask"
...
...
@@ -72,5 +75,31 @@ class SpconvAllocatorKeys:
# MaskArgSortFwd = "MaskArgSortFwd"
MaskArgSortBwd
=
"MaskArgSortBwd"
MaskOutputFwd
=
"MaskOutputFwd"
OutFeatures
=
"OutFeatures"
Features
=
"Features"
Filters
=
"Filters"
OutBp
=
"OutBp"
DIn
=
"DIn"
DFilters
=
"DFilters"
InpBuffer
=
"InpBuffer"
OutBuffer
=
"OutBuffer"
IndicePairsUniq
=
"IndicePairsUniq"
IndicePairsUniqBackup
=
"IndicePairsUniqBackup"
HashKOrKV
=
"HashKOrKV"
HashV
=
"HashV"
ThrustTemp
=
"ThrustTemp"
SPCONV_DEBUG_WEIGHT
=
False
SPCONV_CPP_INDICE_PAIRS
=
True
SPCONV_CPP_INDICE_PAIRS_IGEMM
=
True
SPCONV_CPP_GEMM
=
True
\ No newline at end of file
spconv/core.py
View file @
899008fa
...
...
@@ -16,9 +16,10 @@ from cumm.gemm.main import gen_shuffle_params_v2 as gen_shuffle_params, GemmAlgo
from
cumm.gemm
import
kernel
from
typing
import
List
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvFwd
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.algocore
import
get_gemm_algo_desp_from_param
from
spconv.constants
import
NDIM_DONT_CARE
...
...
@@ -402,32 +403,6 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_SIMT_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvBwdWeight
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_VOLTA_PARAMS
=
[
...
...
@@ -693,6 +668,181 @@ IMPLGEMM_TURING_PARAMS = [
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# all int8 kernels use nvrtc.
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
32
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
32
,
64
,
64
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
64
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
32
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
128
,
128
,
64
),
(
64
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
*
gen_conv_params
(
ConvFwd
,
(
64
,
128
,
64
),
(
32
,
64
,
64
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"s8,s8,s8,s32,s32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
)),
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
,
is_nvrtc
=
False
),
# *gen_conv_params(ConvFwd, (32, 32, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 64, 32), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
# *gen_conv_params(ConvFwd, (32, 32, 64), (32, 32, 32),
# NDIM_DONT_CARE,
# ConvIterAlgo.Optimized,
# 2, ["s8,s8,s8,s32,s32"],
# NHWC,
# NHWC,
# NHWC,
# GemmAlgo.Turing,
# TensorOp((8, 8, 16)),
# mask_sparse=True,
# increment_k_first=True,
# access_per_vector=0,
# is_nvrtc=True),
]
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
...
...
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
899008fa
...
...
@@ -48,7 +48,7 @@ class SpconvOps:
"""
...
@staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_stage2(indices: Tensor, hashdata_k: Tensor, hashdata_v: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, indice_pairs_uniq_before_sort: Tensor, out_inds: Tensor,
indice_num_per_loc: Tensor,
num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0
, use_bound_algo: bool = False
) -> int:
"""
Args:
indices:
...
...
@@ -58,6 +58,7 @@ class SpconvOps:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
indice_num_per_loc:
num_out_act:
batch_size:
output_dims:
...
...
@@ -68,6 +69,7 @@ class SpconvOps:
dilation:
transposed:
stream_int:
use_bound_algo:
"""
...
@staticmethod
...
...
@@ -191,6 +193,31 @@ class SpconvOps:
"""
...
@staticmethod
def indice_maxpool(out_features: Tensor, features: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, stream: int = 0) -> None:
"""
Args:
out_features:
features:
indice_pairs:
indice_pair_num:
num_activate_out:
stream:
"""
...
@staticmethod
def indice_maxpool_backward(din: Tensor, features: Tensor, out_features: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, stream: int = 0) -> None:
"""
Args:
din:
features:
out_features:
out_bp:
indice_pairs:
indice_pair_num:
stream:
"""
...
@staticmethod
def maxpool_implicit_gemm_forward(out: Tensor, inp: Tensor, inds: Tensor, stream: int = 0) -> None:
"""
Args:
...
...
@@ -369,7 +396,18 @@ class SpconvOps:
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
def get_indice_gen_workspace_size(kv: int, num_act_in: int, num_act_out_bound: int, subm: bool, use_int64_hash_k: bool) -> int:
"""
Args:
kv:
num_act_in:
num_act_out_bound:
subm:
use_int64_hash_k:
"""
...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0, num_out_act_bound: int = -1) -> Tuple[Tensor, int]:
"""
Args:
allocator:
...
...
@@ -386,10 +424,11 @@ class SpconvOps:
transposed:
is_train:
stream_int:
num_out_act_bound:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) ->
None
:
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0
, num_out_act_bound: int = -1
) ->
int
:
"""
Args:
allocator:
...
...
@@ -405,12 +444,6 @@ class SpconvOps:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
num_out_act_bound:
"""
...
spconv/core_cc/csrc/sparse/alloc.pyi
View file @
899008fa
...
...
@@ -2,25 +2,29 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
def zeros(self, name: str, shape: List[int], dtype: int, device: int
, is_temp_memory: bool = False, stream: int = 0
) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
is_temp_memory:
stream:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
def empty(self, name: str, shape: List[int], dtype: int, device: int
, is_temp_memory: bool = False, stream: int = 0
) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
is_temp_memory:
stream:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int
, is_temp_memory: bool = False, stream: int = 0
) -> Tensor:
"""
Args:
name:
...
...
@@ -28,9 +32,11 @@ class ExternalAllocator:
value:
dtype:
device:
is_temp_memory:
stream:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int
, is_temp_memory: bool = False, stream: int = 0
) -> Tensor:
"""
Args:
name:
...
...
@@ -38,6 +44,14 @@ class ExternalAllocator:
value:
dtype:
device:
is_temp_memory:
stream:
"""
...
def get_tensor_by_name(self, name: str) -> Tensor:
"""
Args:
name:
"""
...
def free(self, ten: Tensor) -> None:
...
...
spconv/core_cc/csrc/sparse/convops/__init__.pyi
0 → 100644
View file @
899008fa
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from ...csrc.sparse.convops import ExternalSpconvMatmul
class GemmTuneResult:
algo_desp: GemmAlgoDesp
arch: Tuple[int, int]
splitk: int
def is_valid(self) -> bool: ...
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, algo_desp: GemmAlgoDesp, arch: Tuple[int, int], splitk: int) -> None:
"""
Args:
algo_desp:
arch:
splitk:
"""
...
class ConvTuneResult:
algo_desp: ConvAlgoDesp
arch: Tuple[int, int]
splitk: int
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, algo_desp: ConvAlgoDesp, arch: Tuple[int, int], splitk: int) -> None:
"""
Args:
algo_desp:
arch:
splitk:
"""
...
def is_valid(self) -> bool: ...
class ExternalSpconvMatmul:
def indice_conv_init_gemm(self, features_n: str, filters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, kv_center: int, out_channel: int, stream_int: int = 0) -> Tensor:
"""
Args:
features_n:
filters_n:
all_weight_is_krsc:
is_kc_not_ck:
kv_center:
out_channel:
stream_int:
"""
...
def indice_conv_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, nhot: int, index: int) -> None:
"""
Args:
inp_buffer_n:
out_buffer_n:
filters_n:
all_weight_is_krsc:
is_kc_not_ck:
nhot:
index:
"""
...
def indice_conv_bwd_init_gemm(self, features_n: str, filters_n: str, out_bp_n: str, dfilters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, kv_center: int, stream_int: int = 0) -> Tensor:
"""
Args:
features_n:
filters_n:
out_bp_n:
dfilters_n:
all_weight_is_krsc:
is_kc_not_ck:
kv_center:
stream_int:
"""
...
def indice_conv_bwd_cpu_gemm(self, inp_buffer_n: str, out_buffer_n: str, filters_n: str, dfilters_n: str, all_weight_is_krsc: bool, is_kc_not_ck: bool, nhot: int, index: int) -> None:
"""
Args:
inp_buffer_n:
out_buffer_n:
filters_n:
dfilters_n:
all_weight_is_krsc:
is_kc_not_ck:
nhot:
index:
"""
...
class SimpleExternalSpconvMatmul(ExternalSpconvMatmul):
def __init__(self, alloc) -> None:
"""
Args:
alloc:
"""
...
spconv/core_cc/csrc/sparse/convops/convops.pyi
0 → 100644
View file @
899008fa
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
from spconv.core_cc.csrc.sparse.convops import ConvTuneResult
from cumm.tensorview import CUDAKernelTimer
class ConvTunerSimple:
def __init__(self, desps: List[ConvAlgoDesp]) -> None:
"""
Args:
desps:
"""
...
@staticmethod
def get_available_algo_str_from_arch(arch: Tuple[int, int]) -> List[str]:
"""
Args:
arch:
"""
...
def get_all_available(self, inp: Tensor, weight: Tensor, out: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], op_type: int, mask_width: int, auto_fp32_accum: bool, fp32_accum: bool) -> List[ConvAlgoDesp]:
"""
Args:
inp:
weight:
out:
layout_i:
layout_w:
layout_o:
interleave_i:
interleave_w:
interleave_o:
arch:
op_type:
mask_width:
auto_fp32_accum:
fp32_accum:
"""
...
def cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
"""
Args:
desp:
arch:
stream_int:
"""
...
def tune_and_cache(self, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, layout_i: int, layout_w: int, layout_o: int, interleave_i: int, interleave_w: int, interleave_o: int, arch: Tuple[int, int], mask: Tensor, mask_argsort: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, mask_output: Tensor = Tensor(), alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, auto_fp32_accum: bool = True, fp32_accum: bool = False, num_run: int = 5) -> Tuple[ConvTuneResult, float]:
"""
Args:
op_type:
inp:
weight:
output:
layout_i:
layout_w:
layout_o:
interleave_i:
interleave_w:
interleave_o:
arch:
mask:
mask_argsort:
indices:
reverse_mask:
mask_filter:
mask_width:
mask_output:
alpha:
beta:
stream_int:
auto_fp32_accum:
fp32_accum:
num_run:
"""
...
def get_tuned_algo(self, op_type: int, i_dtype: int, w_dtype: int, o_dtype: int, k: int, c: int, arch: Tuple[int, int], mask_width: int = -1) -> Tuple[Any, bool]:
"""
Args:
op_type:
i_dtype:
w_dtype:
o_dtype:
k:
c:
arch:
mask_width:
"""
...
def run_with_tuned_result(self, profile_res, op_type: int, inp: Tensor, weight: Tensor, output: Tensor, mask: Tensor, mask_argsort: Tensor, mask_output: Tensor, indices: Tensor, reverse_mask: bool, mask_filter: int = 0xffffffff, mask_width: int = -1, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, workspace: Tensor = Tensor(), verbose: bool = False, timer: CUDAKernelTimer = CUDAKernelTimer(false), force_nvrtc: bool = False) -> None:
"""
Args:
profile_res:
op_type:
inp:
weight:
output:
mask:
mask_argsort:
mask_output:
indices:
reverse_mask:
mask_filter:
mask_width:
alpha:
beta:
stream_int:
workspace:
verbose:
timer:
force_nvrtc:
"""
...
def query_workspace_size(self, desp: ConvAlgoDesp, splitk: int, op_type: int, N: int, C: int, K: int, kv: int) -> int:
"""
Args:
desp:
splitk:
op_type:
N:
C:
K:
kv:
"""
...
spconv/core_cc/csrc/sparse/convops/gemmops.pyi
0 → 100644
View file @
899008fa
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
from spconv.core_cc.csrc.sparse.convops import GemmTuneResult
from cumm.tensorview import CUDAKernelTimer
class GemmTunerSimple:
def __init__(self, desps: List[GemmAlgoDesp]) -> None:
"""
Args:
desps:
"""
...
@staticmethod
def get_available_algo_str_from_arch(arch: Tuple[int, int]) -> List[str]:
"""
Args:
arch:
"""
...
def get_all_available(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int) -> List[GemmAlgoDesp]:
"""
Args:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
"""
...
def cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int], stream_int: int) -> NVRTCParams:
"""
Args:
desp:
arch:
stream_int:
"""
...
def tune_and_cache(self, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, stream_int: int = 0, num_run: int = 5) -> Tuple[GemmTuneResult, float]:
"""
Args:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
a_inds:
b_inds:
c_inds:
hint:
alpha:
beta:
stream_int:
num_run:
"""
...
def get_tuned_algo(self, a_dtype: int, b_dtype: int, c_dtype: int, a_shape: List[int], b_shape: List[int], c_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], shuffle_type: int, a_inds_shape: List[int], b_inds_shape: List[int], c_inds_shape: List[int], hint: int = 0) -> Tuple[Any, bool]:
"""
Args:
a_dtype:
b_dtype:
c_dtype:
a_shape:
b_shape:
c_shape:
trans_a:
trans_b:
trans_c:
arch:
shuffle_type:
a_inds_shape:
b_inds_shape:
c_inds_shape:
hint:
"""
...
def run_with_tuned_result(self, profile_res, a: Tensor, b: Tensor, c: Tensor, trans_a: bool, trans_b: bool, trans_c: bool, arch: Tuple[int, int], stream_int: int, shuffle_type: int, a_inds: Tensor, b_inds: Tensor, c_inds: Tensor, hint: int = 0, alpha: float = 1.0, beta: float = 0.0, workspace: Tensor = Tensor(), timer: CUDAKernelTimer = CUDAKernelTimer(False), force_nvrtc: bool = False) -> None:
"""
Args:
profile_res:
a:
b:
c:
trans_a:
trans_b:
trans_c:
arch:
stream_int:
shuffle_type:
a_inds:
b_inds:
c_inds:
hint:
alpha:
beta:
workspace:
timer:
force_nvrtc:
"""
...
spconv/core_cc/csrc/sparse/convops/spops.pyi
0 → 100644
View file @
899008fa
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvGemmOps:
@staticmethod
def get_compute_capability(index: int = -1) -> Tuple[int, int]:
"""
Args:
index:
"""
...
@staticmethod
def indice_conv(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
Args:
allocator:
ext_mm:
gemm_tuner:
all_w_is_krsc:
filter_hwio:
features:
filters:
indice_pairs:
indice_pair_num:
num_activate_out:
inverse:
subm:
algo:
stream_int:
"""
...
@staticmethod
def indice_conv_backward(allocator, ext_mm, gemm_tuner, all_w_is_krsc: bool, filter_hwio: bool, features: Tensor, filters: Tensor, out_bp: Tensor, indice_pairs: Tensor, indice_pair_num: Tensor, inverse: bool = False, subm: bool = False, algo: int = 0, stream_int: int = 0) -> None:
"""
Args:
allocator:
ext_mm:
gemm_tuner:
all_w_is_krsc:
filter_hwio:
features:
filters:
out_bp:
indice_pairs:
indice_pair_num:
inverse:
subm:
algo:
stream_int:
"""
...
@staticmethod
def implicit_gemm(allocator, conv_tuner, features: Tensor, filters: Tensor, pair_fwd: Tensor, pair_mask_fwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], num_activate_out: int, masks: Tensor, is_train: bool = False, is_subm: bool = False, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> int:
"""
Args:
allocator:
conv_tuner:
features:
filters:
pair_fwd:
pair_mask_fwd_splits:
mask_argsort_fwd_splits:
num_activate_out:
masks:
is_train:
is_subm:
stream_int:
timer:
auto_fp32_accum:
fp32_accum:
"""
...
@staticmethod
def implicit_gemm_backward(allocator, conv_tuner, features: Tensor, filters: Tensor, out_bp: Tensor, pair_fwd: Tensor, pair_bwd: Tensor, pair_mask_fwd_splits: List[Tensor], pair_mask_bwd_splits: List[Tensor], mask_argsort_fwd_splits: List[Tensor], mask_argsort_bwd_splits: List[Tensor], mask_output_fwd: Tensor, masks: Tensor, mask_width: int, is_subm: bool, stream_int: int = 0, timer: CUDAKernelTimer = CUDAKernelTimer(False), auto_fp32_accum: bool = True, fp32_accum: bool = False) -> None:
"""
Args:
allocator:
conv_tuner:
features:
filters:
out_bp:
pair_fwd:
pair_bwd:
pair_mask_fwd_splits:
pair_mask_bwd_splits:
mask_argsort_fwd_splits:
mask_argsort_bwd_splits:
mask_output_fwd:
masks:
mask_width:
is_subm:
stream_int:
timer:
auto_fp32_accum:
fp32_accum:
"""
...
spconv/core_cc/cumm/common.pyi
View file @
899008fa
...
...
@@ -3,3 +3,10 @@ from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo:
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod
def arch_is_compiled(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
spconv/core_cc/cumm/gemm/main.pyi
View file @
899008fa
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import GemmParams
class GemmMainUnitTest:
@staticmethod
def get_all_algo_desp() -> List[
Any
]: ...
def get_all_algo_desp() -> List[
GemmAlgoDesp
]: ...
@staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type:
str = "0"
, a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type:
int = 0
, a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
"""
Args:
a_shape:
...
...
spconv/csrc/sparse/all.py
View file @
899008fa
...
...
@@ -26,7 +26,7 @@ from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndice
from
.maxpool
import
IndiceMaxPool
,
IndiceMaxPoolCPU
from
.gather
import
GatherCPU
from
.alloc
import
ExternalAllocator
,
ThrustAllocator
from
spconv.constants
import
Spconv
Alloc
ator
Keys
from
spconv.constants
import
AllocKeys
class
CustomThrustLib
(
pccm
.
Class
):
def
__init__
(
self
):
...
...
@@ -34,7 +34,7 @@ class CustomThrustLib(pccm.Class):
self
.
add_dependency
(
ThrustLib
)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if
compat
.
InLinux
:
self
.
build_meta
.
add_public_cflags
(
"nvcc"
,
"-Xcompiler
"
,
"
-fno-gnu-unique"
,
"-Xcompiler
"
,
"
-fvisibility=hidden"
)
self
.
build_meta
.
add_public_cflags
(
"nvcc"
,
"-Xcompiler
-fno-gnu-unique"
,
"-Xcompiler
-fvisibility=hidden"
)
class
ThrustCustomAllocatorV2
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
...
...
@@ -76,6 +76,7 @@ class SpconvOps(pccm.Class):
super
().
__init__
()
self
.
add_dependency
(
ThrustCustomAllocatorV2
,
ExternalAllocator
,
GemmBasicHost
,
ThrustAllocator
)
self
.
ndims
=
[
1
,
2
,
3
,
4
]
self
.
cuda_common_kernel
=
CudaCommonKernel
()
for
ndim
in
self
.
ndims
:
p2v
=
Point2Voxel
(
dtypes
.
float32
,
ndim
)
p2v_cpu
=
Point2VoxelCPU
(
dtypes
.
float32
,
ndim
)
...
...
@@ -102,6 +103,11 @@ class SpconvOps(pccm.Class):
indices
,
f
"SpconvIndices
{
ndim
}
D"
)
for
name
in
dir
(
AllocKeys
):
if
not
name
.
startswith
(
"__"
):
v
=
getattr
(
AllocKeys
,
name
)
self
.
add_static_const
(
"k"
+
name
,
"auto"
,
f
"tv::make_const_string(
{
pccm
.
literal
(
v
)
}
)"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
cumm_version
(
self
):
...
...
@@ -194,12 +200,15 @@ class SpconvOps(pccm.Class):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata_k, hashdata_v"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"use_bound_algo"
,
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
...
...
@@ -225,9 +234,11 @@ class SpconvOps(pccm.Class):
}}
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2(indices,
hashdata_k, hashdata_v, indice_pairs,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds, num_out_act,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds,
indice_num_per_loc, num_out_act,
batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int,
use_bound_algo);
}}
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
...
...
@@ -481,6 +492,93 @@ class SpconvOps(pccm.Class):
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
indice_maxpool
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"out_features, features"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
add_dependency
(
IndiceMaxPoolCPU
)
if
not
CUMM_CPU_ONLY_BUILD
:
code
.
add_dependency
(
IndiceMaxPool
)
code
.
raw
(
f
"""
tv::check_shape(out_features, {{-1, features.dim(1)}});
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
"""
)
with
code
.
for_
(
"int i = 0; i < indice_pair_num.dim(0); ++i"
):
code
.
raw
(
f
"""
int nhot = indice_pair_num_cpu_ptr[i];
nhot = std::min(nhot, int(indice_pairs.dim(2)));
if (nhot <= 0){{
continue;
}}
auto inp_indices = indice_pairs[0][i].slice_first_axis(0, nhot);
auto out_indices = indice_pairs[1][i].slice_first_axis(0, nhot);
if (features.is_cpu()){{
IndiceMaxPoolCPU::forward(out_features, features, out_indices, inp_indices);
}}
"""
)
if
not
CUMM_CPU_ONLY_BUILD
:
with
code
.
else_
():
code
.
raw
(
f
"""
IndiceMaxPool::forward(out_features, features, out_indices, inp_indices, stream);
"""
)
else
:
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ")
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
indice_maxpool_backward
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"din, features, out_features, out_bp"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
add_dependency
(
IndiceMaxPoolCPU
)
if
not
CUMM_CPU_ONLY_BUILD
:
code
.
add_dependency
(
IndiceMaxPool
)
code
.
raw
(
f
"""
tv::check_shape(din, features.shape());
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
"""
)
with
code
.
for_
(
"int i = 0; i < indice_pair_num.dim(0); ++i"
):
code
.
raw
(
f
"""
int nhot = indice_pair_num_cpu_ptr[i];
nhot = std::min(nhot, int(indice_pairs.dim(2)));
if (nhot <= 0){{
continue;
}}
auto inp_indices = indice_pairs[0][i].slice_first_axis(0, nhot);
auto out_indices = indice_pairs[1][i].slice_first_axis(0, nhot);
if (features.is_cpu()){{
IndiceMaxPoolCPU::backward(out_features, features, out_bp, din, out_indices, inp_indices);
}}
"""
)
if
not
CUMM_CPU_ONLY_BUILD
:
with
code
.
else_
():
code
.
raw
(
f
"""
IndiceMaxPool::backward(out_features, features, out_bp, din, out_indices, inp_indices, stream);
"""
)
else
:
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ")
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
maxpool_implicit_gemm_forward
(
self
):
...
...
@@ -597,7 +695,7 @@ class SpconvOps(pccm.Class):
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
C
uda
C
ommon
K
ernel
()
)
code
.
add_param_class
(
"cudakers"
,
self
.
c
uda
_c
ommon
_k
ernel
)
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
...
...
@@ -613,7 +711,7 @@ class SpconvOps(pccm.Class):
auto thrust_ctx = thrust::cuda::par.on(stream_cu);
thrust::stable_sort_by_key(thrust_ctx, ptr_tr, ptr_tr + data.dim(0), ptr_k, SmallOrEqualTo<uint32_t>());
}});
tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
//
tv::ssprint("SORT BY KEY TIME", data.dim(0), timer.report() / 1000.0);
return indices;
"""
)
return
code
.
ret
(
"tv::Tensor"
)
...
...
@@ -646,7 +744,7 @@ class SpconvOps(pccm.Class):
}}
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
C
uda
C
ommon
K
ernel
()
)
code
.
add_param_class
(
"cudakers"
,
self
.
c
uda
_c
ommon
_k
ernel
)
if
not
use_allocator
:
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
...
...
@@ -715,7 +813,7 @@ class SpconvOps(pccm.Class):
}}
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
C
uda
C
ommon
K
ernel
()
)
code
.
add_param_class
(
"cudakers"
,
self
.
c
uda
_c
ommon
_k
ernel
)
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>();
...
...
@@ -774,7 +872,7 @@ class SpconvOps(pccm.Class):
}}
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
C
uda
C
ommon
K
ernel
()
)
code
.
add_param_class
(
"cudakers"
,
self
.
c
uda
_c
ommon
_k
ernel
)
if
not
use_allocator
:
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
...
...
@@ -1141,6 +1239,26 @@ class SpconvOps(pccm.Class):
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_indice_gen_workspace_size
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"kv"
,
"size_t"
)
code
.
arg
(
"num_act_in"
,
"size_t"
)
code
.
arg
(
"num_act_out_bound"
,
"size_t"
)
code
.
arg
(
"subm, use_int64_hash_k"
,
"bool"
)
code
.
raw
(
f
"""
if (subm){{
return 2 * num_act_in * (use_int64_hash_k ? 2 : 3) * sizeof(int);
}}else{{
size_t pair_single_size = kv * num_act_in;
size_t ind_uniq_and_bkp_size = (pair_single_size + 1) * 2 * (use_int64_hash_k ? sizeof(int64_t) : sizeof(int32_t));
size_t hash_size = 2 * num_act_out_bound * (use_int64_hash_k ? 2 : 3) * sizeof(int);
return ind_uniq_and_bkp_size + hash_size;
}}
"""
)
return
code
.
ret
(
"std::size_t"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_indice_pairs_implicit_gemm
(
self
):
...
...
@@ -1154,6 +1272,8 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"subm, transposed, is_train"
,
f
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"num_out_act_bound"
,
f
"int"
,
"-1"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
...
...
@@ -1192,13 +1312,13 @@ class SpconvOps(pccm.Class):
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
if (subm){{
pair = allocator.full_int(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
Pair
)
}
,
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
Pair
)
}
,
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}else{{
pair = allocator.full_int(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
Pair
)
}
,
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
Pair
)
}
,
{{kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
IndiceNumPerLoc
)
}
,
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
,
{{kv}}, indices.dtype(), indices.device());
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
...
...
@@ -1213,39 +1333,48 @@ class SpconvOps(pccm.Class):
mask_tensor_ptr[1] = uint32_t(second);
}}
else{{
mask_tensor_ptr[
1
] = 0xffffffff;
mask_tensor_ptr[
0
] = 0xffffffff;
}}
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
int num_act_out = 0;
if (subm){{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
num_act_out = indices.dim(0);
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_points * 2}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{num_points * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
auto pair_mask = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
PairMask
)
}
,
auto pair_mask = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, false, stream_int);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::
u
int32, 0);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}else{{
auto pair_bwd = pair;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}},
indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniq
)
}
);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}},
indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
...
...
@@ -1253,28 +1382,34 @@ class SpconvOps(pccm.Class):
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
OutIndices
)
}
,
out_inds = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutIndices
)
}
,
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
auto pair_fwd = allocator.full_int(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
PairFwd
)
}
,
auto pair_fwd = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
PairFwd
)
}
,
{{kv, num_act_out}}, -1, indices.dtype(), indices.device());
auto pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
PairMask
)
}
,
auto pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_out}}, tv::uint32, 0);
auto pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
PairMaskBwd
)
}
,
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
PairMaskBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
...
...
@@ -1283,23 +1418,24 @@ class SpconvOps(pccm.Class):
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
auto mask_argsort_fwd = allocator.empty(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::
u
int32, 0);
auto mask_argsort_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::int32, 0);
tv::Tensor mask_argsort_bwd = tv::Tensor();
if (is_train){{
mask_argsort_bwd = allocator.zeros(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
MaskArgSortBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::
u
int32, 0);
mask_argsort_bwd = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
MaskArgSortBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::int32, 0);
}}
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
auto mask_tensor_sub = mask_tensor.slice_first_axis(j, j + 1);
if (!is_train){{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor
[j]
, mask_argsort_fwd[j], stream_int);
mask_tensor
_sub
, mask_argsort_fwd[j], stream_int);
}}else{{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor
[j]
, mask_argsort_fwd[j], stream_int);
mask_tensor
_sub
, mask_argsort_fwd[j], stream_int);
sort_1d_by_key_split_allocator_v2(pair_mask_bwd[j], thrustalloc,
mask_tensor
[j]
, mask_argsort_bwd[j], stream_int);
mask_tensor
_sub
, mask_argsort_bwd[j], stream_int);
}}
}}
}}else{{
...
...
@@ -1314,9 +1450,9 @@ class SpconvOps(pccm.Class):
}}
}}
}}
return
mask_tensor
;
return
std::make_tuple(mask_tensor, num_act_out)
;
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"
std::tuple<
tv::Tensor
, int>
"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
...
...
@@ -1329,15 +1465,12 @@ class SpconvOps(pccm.Class):
code
.
arg
(
"algo"
,
"int"
)
code
.
arg
(
"ksize, stride, padding, dilation, out_padding"
,
f
"std::vector<int>"
)
code
.
arg
(
"subm, transposed"
,
f
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
"""
)
return
code
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"num_out_act_bound"
,
f
"int"
,
"-1"
)
code
.
raw
(
f
"""
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative");
...
...
@@ -1362,15 +1495,17 @@ class SpconvOps(pccm.Class):
}}
}}
tv::Tensor pair;
pair = allocator.full_int(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
Pair
)
}
,
pair = allocator.full_int(
{
pccm
.
literal
(
AllocKeys
.
Pair
)
}
,
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
IndiceNumPerLoc
)
}
,
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
IndiceNumPerLoc
)
}
,
{{kv}}, indices.dtype(), indices.device());
tv::Tensor out_inds;
int num_act_out = -1;
"""
)
with
code
.
if_
(
"subm"
):
code
.
raw
(
f
"""
num_act_out = indices.dim(0);
if (indices.is_cpu()){{
generate_subm_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation);
...
...
@@ -1384,12 +1519,15 @@ class SpconvOps(pccm.Class):
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_points * 2}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{num_points * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
...
...
@@ -1406,10 +1544,10 @@ class SpconvOps(pccm.Class):
with
code
.
else_
():
code
.
raw
(
f
"""
if (indices.is_cpu()){{
out_inds = allocator.empty(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
OutIndices
)
}
,
TV_ASSERT_RT_ERR(num_out_act_bound <= 0, "cpu algo don't support out bound")
out_inds = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutIndices
)
}
,
{{kv * indices.dim(0), indices.dim(1)}}, indices.dtype(), -1);
generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
num_act_out =
generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed);
}}
...
...
@@ -1422,9 +1560,13 @@ class SpconvOps(pccm.Class):
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_guard = allocator.empty_guard(
{{int64_t(pair.numel() / 2 + 1)}}, indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniq
)
}
);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard(
{{int64_t(pair.numel() / 2 + 1)}}, indice_uniq_dtype, 0,
{
pccm
.
literal
(
AllocKeys
.
IndicePairsUniqBackup
)
}
);
generate_conv_inds_stage1(indices, pair, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
...
...
@@ -1432,27 +1574,35 @@ class SpconvOps(pccm.Class):
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
bool use_bound_algo = false;
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
use_bound_algo = true;
}}
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty(
{
pccm
.
literal
(
Spconv
Alloc
ator
Keys
.
OutIndices
)
}
,
out_inds = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutIndices
)
}
,
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k_guard = allocator.empty_guard({{num_act_out * 2}},
tv::int64, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashV
)
}
);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}},
tv::int32, 0,
{
pccm
.
literal
(
AllocKeys
.
HashKOrKV
)
}
);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
num_act_out =
generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, num_act_out,
out_inds,
indice_num_per_loc,
num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
transposed, stream_int
, use_bound_algo
);
}}
"""
)
else
:
...
...
@@ -1462,18 +1612,6 @@ class SpconvOps(pccm.Class):
}}
"""
)
code
.
raw
(
f
"""
return;
return
num_act_out
;
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
test_allocator
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
raw
(
f
"""
auto guard = allocator.zeros_guard({{1, 2, 3}}, tv::int32, 0);
tv::ssprint("????");
"""
)
return
code
\ No newline at end of file
return
code
.
ret
(
"int"
)
spconv/csrc/sparse/alloc.py
View file @
899008fa
import
pccm
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
from
spconv.constants
import
AllocKeys
class
ExternalAllocatorGuard
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
...
...
@@ -51,6 +53,9 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
@@ -61,6 +66,9 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
@@ -72,6 +80,9 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
@@ -83,6 +94,15 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"value"
,
"float"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
get_tensor_by_name
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
...
...
@@ -105,9 +125,11 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"name"
,
"std::string"
,
"
\"\"
"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
// "" means temp memory
auto ten = zeros(
""
, shape, dtype, device);
auto ten = zeros(
name
, shape, dtype, device
, true, stream
);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
...
...
@@ -120,8 +142,10 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"name"
,
"std::string"
,
"
\"\"
"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto ten = empty(
""
, shape, dtype, device);
auto ten = empty(
name
, shape, dtype, device
, true, stream
);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
...
...
@@ -135,8 +159,10 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"name"
,
"std::string"
,
"
\"\"
"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto ten = full_int(
""
, shape, value, dtype, device);
auto ten = full_int(
name
, shape, value, dtype, device
, true, stream
);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
...
...
@@ -150,8 +176,10 @@ class ExternalAllocator(pccm.Class):
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"name"
,
"std::string"
,
"
\"\"
"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto ten = full_float(
""
, shape, value, dtype, device);
auto ten = full_float(
name
, shape, value, dtype, device
, true, stream
);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t);
}});
...
...
@@ -179,7 +207,7 @@ class ThrustAllocator(pccm.Class):
code
.
arg
(
"num_bytes"
,
"std::ptrdiff_t"
)
code
.
ret
(
"char*"
)
code
.
raw
(
f
"""
auto ten = allocator_.empty(
""
, {{num_bytes}}, tv::uint8, 0);
auto ten = allocator_.empty(
{
pccm
.
literal
(
AllocKeys
.
ThrustTemp
)
}
, {{num_bytes}}, tv::uint8, 0);
return reinterpret_cast<char*>(ten.raw_data());
"""
)
return
code
...
...
@@ -193,3 +221,158 @@ class ThrustAllocator(pccm.Class):
return allocator_.free_noexcept(tv::from_blob(ptr, {{num_bytes}}, tv::uint8, 0));
"""
)
return
code
class
StaticAllocator
(
ExternalAllocator
):
"""a simple allocator for tensorrt plugin.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
)
self
.
add_member
(
"tensor_dict_"
,
"std::unordered_map<std::string, tv::Tensor>"
)
self
.
add_member
(
"repr_"
,
"std::string"
)
self
.
add_member
(
"thrust_tmp_tensor_"
,
"tv::Tensor"
)
self
.
grow
=
1.5
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"tensor_dict"
,
"std::unordered_map<std::string, tv::Tensor>"
)
code
.
ctor_init
(
"tensor_dict_"
,
"tensor_dict"
)
code
.
raw
(
f
"""
std::stringstream ss;
for (auto& p : tensor_dict){{
tv::ssprint(ss, p.first, p.second.shape(), tv::dtype_str(p.second.dtype()), "
\\
n");
}}
repr_ = ss.str();
"""
)
return
code
@
pccm
.
member_function
(
virtual
=
True
)
def
_get_raw_and_check
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
raw
(
f
"""
auto res = get_tensor_by_name(name);
size_t total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
TV_ASSERT_RT_ERR(res.nbytes() >= total * tv::bit_size(tv::DType(dtype))
&& res.device() == device, "alloc failed", shape, res.shape());
return tv::from_blob(res.raw_data(), shape, dtype, device);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
zeros
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream));
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.zero_(tvctx);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
empty
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
if (name ==
{
pccm
.
literal
(
AllocKeys
.
ThrustTemp
)
}
){{
// thrust tmp shouldn't inside tensor_dict. use a simple method to allocate
// we assume each allocator always handle one stream
// so we can just use one tensor
tv::Tensor res = thrust_tmp_tensor_;
if (res.empty()){{
res = tv::empty(shape, dtype, device);
thrust_tmp_tensor_ = res;
}}
if (shape[0] > thrust_tmp_tensor_.dim(0)){{
res = tv::empty({{int64_t(shape[0] *
{
self
.
grow
}
)}}, dtype, device);
thrust_tmp_tensor_ = res;
}}
return res;
}}else{{
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob;
}}
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
full_int
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto tvctx = tv::Context();
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.fill_(tvctx, value);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
full_float
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"float"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
arg
(
"is_temp_memory"
,
"bool"
,
"false"
)
code
.
arg
(
"stream"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
auto blob = _get_raw_and_check(name, shape, dtype, device);
return blob.fill_(tvctx, value);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
get_tensor_by_name
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(tensor_dict_.find(name) != tensor_dict_.end(), "can't find", name, "exists:
\\
n", repr_);
return tensor_dict_.at(name);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
free
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"ten"
,
"tv::Tensor"
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
(
virtual
=
True
)
def
free_noexcept
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"ten"
,
"tv::Tensor"
)
return
code
spconv/csrc/sparse/convops.py
View file @
899008fa
from
typing
import
Optional
import
pccm
from
cumm.
gemm.main
import
GemmMainUnitTest
from
cumm.common
import
GemmBasicHost
,
NlohmannJson
,
TensorView
from
cumm.
constants
import
CUMM_CPU_ONLY_BUILD
from
cumm.conv.main
import
ConvMainUnitTest
from
cumm.gemm.algospec.core
import
(
_GEMM_MIN_ARCH_TO_ALGO
,
GemmAlgo
,
ShuffleStrideType
,
get_available_algo_str_from_arch
,
get_min_arch_of_algo_str
)
from
cumm.gemm.main
import
GemmMainUnitTest
from
spconv.constants
import
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
,
AllocKeys
from
spconv.core
import
AlgoHint
,
ConvAlgo
from
spconv.csrc.sparse.gather
import
GatherCPU
from
.alloc
import
ExternalAllocator
from
spconv.core
import
ConvAlgo
from
spconv.constants
import
SpconvAllocatorKeys
from
cumm.constants
import
CUMM_CPU_ONLY_BUILD
from
cumm.common
import
GemmBasicHost
,
TensorView
,
NlohmannJson
from
cumm.common
import
CompileInfo
class
ExternalSpconvMatmul
(
pccm
.
Class
):
"""a helper class to warp matmul operations
because we don't want to implement matmul
(link to cublas/mkl/pytorch) in python package.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
indice_conv_init_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"features_n, filters_n"
,
"std::string"
)
code
.
arg
(
"all_weight_is_krsc, is_kc_not_ck"
,
"bool"
)
code
.
arg
(
"kv_center, out_channel"
,
"int"
)
code
.
arg
(
"stream_int"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented, override this and use preferred blas!!!");
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
indice_conv_cpu_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"inp_buffer_n, out_buffer_n, filters_n"
,
"std::string"
)
code
.
arg
(
"all_weight_is_krsc, is_kc_not_ck"
,
"bool"
)
code
.
arg
(
"nhot, index"
,
"int"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented, override this and use preferred cpu blas!!!");
"""
)
return
code
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
indice_conv_bwd_init_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"features_n, filters_n, out_bp_n, dfilters_n"
,
"std::string"
)
code
.
arg
(
"all_weight_is_krsc, is_kc_not_ck"
,
"bool"
)
code
.
arg
(
"kv_center"
,
"int"
)
code
.
arg
(
"stream_int"
,
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented, override this and use preferred blas!!!");
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
indice_conv_bwd_cpu_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"inp_buffer_n, out_buffer_n, filters_n, dfilters_n"
,
"std::string"
)
code
.
arg
(
"all_weight_is_krsc, is_kc_not_ck"
,
"bool"
)
code
.
arg
(
"nhot, index"
,
"int"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented, override this and use preferred cpu blas!!!");
"""
)
return
code
class
SimpleExternalSpconvMatmul
(
ExternalSpconvMatmul
):
"""a helper class to warp matmul operations
because we don't want to implement matmul
(link to cublas/mkl/pytorch) in python package.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
ExternalAllocator
)
self
.
build_meta
.
add_libraries
(
"cublasLt"
)
self
.
add_include
(
"cublasLt.h"
)
self
.
add_member
(
"alloc_"
,
"ExternalAllocator&"
)
self
.
add_member
(
"handle_"
,
"cublasLtHandle_t"
,
"0"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"alloc"
,
"ExternalAllocator&"
)
code
.
ctor_init
(
"alloc_"
,
"alloc"
)
code
.
raw
(
f
"""
auto stat = cublasLtCreate(&handle_);
TV_ASSERT_RT_ERR(CUBLAS_STATUS_SUCCESS == stat, "err");
"""
)
return
code
@
pccm
.
destructor
def
destructor
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"""
if (handle_){{
cublasLtDestroy(handle_);
}}
"""
)
return
code
@
pccm
.
static_function
def
check_cublas_status
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"status"
,
"cublasStatus_t"
)
code
.
raw
(
f
"""
if (status != CUBLAS_STATUS_SUCCESS) {{
printf("cuBLAS API failed with status %d
\\
n", status);
throw std::logic_error("cuBLAS API failed");
}}
"""
)
return
code
@
pccm
.
static_function
def
tv_dtype_to_blaslt
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"dtype"
,
"tv::DType"
)
code
.
raw
(
f
"""
switch (dtype) {{
case tv::float32:
return CUDA_R_32F;
case tv::float16:
return CUDA_R_16F;
case tv::int32:
return CUDA_R_32I;
case tv::int8:
return CUDA_R_8I;
case tv::uint32:
return CUDA_R_32U;
default:
return CUDA_R_32F;
}}
"""
)
return
code
.
ret
(
"decltype(CUDA_R_16F)"
)
@
pccm
.
static_function
(
inline
=
True
)
def
tv_dtype_to_compute
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"dtype"
,
"tv::DType"
)
with
code
.
macro_if_
(
"CUDART_VERSION >= 11000"
):
code
.
raw
(
f
"""
switch (dtype) {{
case tv::float32:
return CUBLAS_COMPUTE_32F;
case tv::float16:
return CUBLAS_COMPUTE_16F;
case tv::int32:
return CUBLAS_COMPUTE_32I;
case tv::int8:
return CUBLAS_COMPUTE_32F;
case tv::uint32:
return CUBLAS_COMPUTE_32F;
default:
return CUBLAS_COMPUTE_32F;
}}
"""
)
with
code
.
macro_else_
():
code
.
raw
(
f
"""
switch (dtype) {{
case tv::float32:
return CUDA_R_32F;
case tv::float16:
return CUDA_R_16F;
case tv::int32:
return CUDA_R_32I;
case tv::int8:
return CUDA_R_8I;
case tv::uint32:
return CUDA_R_32U;
default:
return CUDA_R_32F;
}}
"""
)
code
.
macro_endif_
()
return
code
.
ret
(
"decltype(auto)"
)
@
pccm
.
static_function
def
matmul_colmajor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"handle"
,
"cublasLtHandle_t"
)
code
.
arg
(
"stream"
,
"cudaStream_t"
)
code
.
arg
(
"a, b, c"
,
"tv::Tensor"
)
code
.
arg
(
"transA, transB"
,
"bool"
)
code
.
raw
(
f
"""
bool transC = false;
auto m = a.dim(int(!transA));
auto k = a.dim(int(transA));
auto k2 = b.dim(int(!transB));
auto n = b.dim(int(transB));
TV_ASSERT_INVALID_ARG(k == k2, "error");
TV_ASSERT_INVALID_ARG(a.dtype() == b.dtype(), "error");
tv::TensorShape c_shape;
if (transC) {{
c_shape = {{m, n}};
}} else {{
c_shape = {{n, m}};
}}
if (c.empty()) {{
c = tv::Tensor(c_shape, a.dtype(), a.device());
}} else {{
TV_ASSERT_INVALID_ARG(c.dim(0) == c_shape[0] && c.dim(1) == c_shape[1],
"error");
}}
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
decltype(CUDA_R_16F) scalarType = CUDA_R_16F;
#if CUDART_VERSION >= 11000
decltype(CUBLAS_COMPUTE_32F) computeType = CUBLAS_COMPUTE_32F;
#endif
if (a.dtype() == tv::float16 && b.dtype() == tv::float16 &&
c.dtype() == tv::float16) {{
scalarType = CUDA_R_16F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_16F;
#endif
}} else if (a.dtype() == tv::float32 && b.dtype() == tv::float32 &&
c.dtype() == tv::float16) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else if (a.dtype() == tv::float32 && b.dtype() == tv::float32 &&
c.dtype() == tv::float32) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else if (a.dtype() == tv::float16 && b.dtype() == tv::float16 &&
c.dtype() == tv::float32) {{
scalarType = CUDA_R_32F;
#if CUDART_VERSION >= 11000
computeType = CUBLAS_COMPUTE_32F;
#endif
}} else {{
TV_THROW_RT_ERR("unsupported");
}}
#if CUDART_VERSION >= 11000
check_cublas_status(
cublasLtMatmulDescCreate(&operationDesc, computeType, scalarType));
#else
check_cublas_status(cublasLtMatmulDescCreate(&operationDesc, scalarType));
#endif
cublasOperation_t transa = !transA ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transb = !transB ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transc = !transC ? CUBLAS_OP_N : CUBLAS_OP_T;
check_cublas_status(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
check_cublas_status(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
// check_cublas_status(cublasLtMatmulDescSetAttribute(
// operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &transc,
// sizeof(transc)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Adesc, tv_dtype_to_blaslt(a.dtype()), transa == CUBLAS_OP_N ? m : k,
transa == CUBLAS_OP_N ? k : m, a.stride(0)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Bdesc, tv_dtype_to_blaslt(b.dtype()), transb == CUBLAS_OP_N ? k : n,
transb == CUBLAS_OP_N ? n : k, b.stride(0)));
// check_cublas_status(cublasLtMatrixLayoutCreate(
// &Cdesc, tv_dtype_to_blaslt(c.dtype()), transc == CUBLAS_OP_N ? m : n,
// transc == CUBLAS_OP_N ? n : m, c.dim(0)));
check_cublas_status(cublasLtMatrixLayoutCreate(
&Cdesc, tv_dtype_to_blaslt(c.dtype()), m, n, c.stride(0)));
cublasLtMatmulHeuristicResult_t heuristicResult = {{}};
cublasLtMatmulPreference_t preference = NULL;
check_cublas_status(cublasLtMatmulPreferenceCreate(&preference));
size_t workspaceSize = 0;
check_cublas_status(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,
sizeof(workspaceSize)));
int returnedResults = 0;
check_cublas_status(cublasLtMatmulAlgoGetHeuristic(
handle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1,
&heuristicResult, &returnedResults));
if (returnedResults == 0) {{
check_cublas_status(CUBLAS_STATUS_NOT_SUPPORTED);
}}
int alpha_storage[4];
int beta_storage[4];
if (scalarType == CUDA_R_32F) {{
*(reinterpret_cast<float *>(alpha_storage)) = 1.0f;
*(reinterpret_cast<float *>(beta_storage)) = 0.0f;
}} else if (scalarType == CUDA_R_16F) {{
*(reinterpret_cast<__half *>(alpha_storage)) = __half(1.0f);
*(reinterpret_cast<__half *>(beta_storage)) = __half(0.0f);
}} else {{
TV_THROW_RT_ERR("unsupported");
}}
check_cublas_status(cublasLtMatmul(
handle, operationDesc, alpha_storage, a.raw_data(), Adesc, b.raw_data(),
Bdesc, beta_storage, c.raw_data(), Cdesc, c.raw_data(), Cdesc,
&heuristicResult.algo, nullptr, 0, stream));
if (preference)
check_cublas_status(cublasLtMatmulPreferenceDestroy(preference));
if (Cdesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc)
check_cublas_status(cublasLtMatrixLayoutDestroy(Adesc));
if (operationDesc)
check_cublas_status(cublasLtMatmulDescDestroy(operationDesc));
return;
"""
)
return
code
@
pccm
.
static_function
def
matmul
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"handle"
,
"cublasLtHandle_t"
)
code
.
arg
(
"stream"
,
"cudaStream_t"
)
code
.
arg
(
"a, b, c"
,
"tv::Tensor"
)
code
.
arg
(
"transA, transB"
,
"bool"
)
code
.
raw
(
f
"""
return matmul_colmajor(handle, stream, b, a, c, transB, transA);
"""
)
return
code
@
pccm
.
member_function
def
indice_conv_init_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"features_n, filters_n"
,
"std::string"
)
code
.
arg
(
"all_weight_is_krsc, is_kc_not_ck"
,
"bool"
)
code
.
arg
(
"kv_center, out_channel"
,
"int"
)
code
.
arg
(
"stream_int"
,
"std::uintptr_t"
)
code
.
raw
(
f
"""
auto features = alloc_.get_tensor_by_name(features_n);
auto filters = alloc_.get_tensor_by_name(filters_n);
TV_ASSERT_RT_ERR(!features.is_cpu(), "only supprt cuda");
auto out_features = alloc_.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{features.dim(0), out_channel}}, features.dtype(), features.device());
if (!all_weight_is_krsc){{
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
if (!is_kc_not_ck){{
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters[kv_center], out_features, false, false);
}}else{{
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters[kv_center], out_features, false, true);
}}
}}else{{
filters = filters.view(out_channel, -1, filters.dim(-1));
matmul(handle_, reinterpret_cast<cudaStream_t>(stream_int),
features, filters.select(1, kv_center), out_features, false, true);
}}
return out_features;
"""
)
return
code
.
ret
(
"tv::Tensor"
)
class
GemmTuneResult
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
...
...
@@ -21,8 +388,8 @@ class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@
pccm
.
member_function
def
is_valid
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0"
)
return
code
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0
;
"
)
return
code
.
ret
(
"bool"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
...
...
@@ -61,7 +428,10 @@ class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@
pccm
.
constructor
def
defaultctor
(
self
):
code
=
pccm
.
code
()
code
.
ctor_init
(
"algo_desp"
,
"tv::gemm::ConvAlgoDesp()"
)
code
.
ctor_init
(
"algo_desp"
,
f
"tv::gemm::ConvAlgoDesp(
{
NDIM_DONT_CARE
}
, tv::gemm::ConvOpType::kForward)"
)
code
.
ctor_init
(
"arch"
,
"std::make_tuple(-1, -1)"
)
code
.
ctor_init
(
"splitk"
,
"-1"
)
return
code
...
...
@@ -84,124 +454,1738 @@ class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
@
pccm
.
member_function
def
is_valid
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0"
)
return
code
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0;"
)
return
code
.
ret
(
"bool"
)
class
GemmTunerSimple
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
gemm_cu
:
GemmMainUnitTest
,
conv_cu
:
ConvMainUnitTest
):
def
__init__
(
self
,
gemm_cu
:
Optional
[
GemmMainUnitTest
]):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
ConvTuneResult
,
TensorView
)
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
TensorView
,
GemmBasicHost
,
CompileInfo
)
if
gemm_cu
is
not
None
:
self
.
add_param_class
(
"gemm"
,
gemm_cu
,
"GemmMain"
)
self
.
add_param_class
(
"conv"
,
conv_cu
,
"ConvMain"
)
if
not
CUMM_CPU_ONLY_BUILD
:
assert
gemm_cu
is
not
None
self
.
add_include
(
"tensorview/profile/cuda_profiler.h"
)
self
.
add_include
(
"tensorview/utility/tuplehash.h"
)
self
.
add_include
(
"mutex"
)
self
.
add_typedef
(
"static_key_t"
,
"std::tuple<bool, bool, bool, int, "
"int, int, int, std::string>"
)
self
.
add_typedef
(
"algo_cache_key_t"
,
"std::tuple<int, "
"int, int, int, int>"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::GemmAlgoDesp>"
)
self
.
add_member
(
"nvrtc_progs_"
,
"std::unordered_map<std::string, tv::NVRTCProgram>"
)
self
.
add_member
(
"nvrtc_caches_"
,
"std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>"
)
self
.
add_member
(
"static_key_to_desps_"
,
"std::unordered_map<static_key_t, std::vector<tv::gemm::GemmAlgoDesp>>"
)
self
.
add_member
(
"prebuilt_names_"
,
"std::unordered_set<std::string>"
)
self
.
add_member
(
"mutex_"
,
"std::mutex"
)
self
.
add_member
(
"nk_forward_cache_, nk_dgrad_cache_, mn_cache_"
,
"std::unordered_map<algo_cache_key_t, GemmTuneResult>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desps"
,
"std::vector<tv::gemm::GemmAlgoDesp>"
)
code
.
arg
(
"nvrtc_progs"
,
"std::unordered_map<std::string, std::string>"
)
code
.
ctor_init
(
"desps_"
,
"desps"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
code
.
raw
(
f
"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
for (auto& d : desps){{
static_key_t static_key = std::make_tuple(d.trans_a(), d.trans_b(), d.trans_c(), d.dtype_a, d.dtype_b,
d.dtype_c, int(d.shuffle_type), d.algo);
auto& vec = static_key_to_desps_[static_key];
vec.push_back(d);
}}
for (auto desp : GemmMain::get_all_algo_desp()){{
prebuilt_names_.insert(desp.__repr__());
}}
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_available_algo_str_from_arch
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
raw
(
f
"""
std::vector<std::string> res;
"""
)
for
i
in
range
(
len
(
_GEMM_MIN_ARCH_TO_ALGO
)
-
1
,
-
1
,
-
1
):
arch_cur
,
algos
=
_GEMM_MIN_ARCH_TO_ALGO
[
i
]
code
.
raw
(
f
"""
auto arch_cur_
{
i
}
= std::make_tuple(int(
{
arch_cur
[
0
]
}
), int(
{
arch_cur
[
1
]
}
));
"""
)
with
code
.
if_
(
f
"arch >= arch_cur_
{
i
}
"
):
for
algo
in
algos
:
code
.
raw
(
f
"""
res.push_back(
{
pccm
.
literal
(
algo
)
}
);
"""
)
code
.
raw
(
f
"return res;"
)
return
code
.
ret
(
"std::vector<std::string>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
get_all_available
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"a, b, c"
,
"tv::Tensor"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"nvrtc_progs"
,
"std::unordered_map<std::string, std::string>"
)
code
.
ctor_init
(
"desps_"
,
"desps"
)
code
.
arg
(
"shuffle_type"
,
"int"
)
code
.
raw
(
f
"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
if (trans_c){{
trans_a = !trans_a;
trans_b = !trans_b;
std::swap(trans_a, trans_b);
std::swap(a, b);
trans_c = false;
}}
auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled(arch);
for (auto algo : avail_algos){{
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
int(b.dtype()), int(c.dtype()), shuffle_type, algo);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
continue;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch >= std::make_tuple(7, 5) && desp.algo ==
{
pccm
.
literal
(
GemmAlgo
.
Volta
.
value
)
}
){{
continue;
}}
auto lda = a.stride(0);
auto ldb = b.stride(0);
auto ldc = c.stride(0);
if (desp.supported_ldx(lda, ldb, ldc)){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
}}
}}
}}
}}
return finally_algos;
"""
)
return
code
return
code
.
ret
(
"std::vector<tv::gemm::GemmAlgoDesp>"
,
pyanno
=
"List[cumm.tensorview.gemm.GemmAlgoDesp]"
)
@
pccm
.
member_function
def
extract_mnk
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"a_shape, b_shape"
,
"tv::TensorShape"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"shuffle_type"
,
"int"
)
code
.
arg
(
"a_inds_shape, b_inds_shape, c_inds_shape"
,
"tv::TensorShape"
)
code
.
arg
(
"hint"
,
"int"
,
f
"
{
AlgoHint
.
NoHint
.
value
}
"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"std::tuple<int, int, int>"
)
code
.
raw
(
f
"""
std::vector<int64_t> a_shape_vec(a_shape.begin(), a_shape.end());
std::vector<int64_t> b_shape_vec(b_shape.begin(), b_shape.end());
std::vector<int64_t> a_inds_shape_vec(a_inds_shape.begin(), a_inds_shape.end());
std::vector<int64_t> b_inds_shape_vec(b_inds_shape.begin(), b_inds_shape.end());
std::vector<int64_t> c_inds_shape_vec(c_inds_shape.begin(), c_inds_shape.end());
class
ConvGemmOps
(
pccm
.
ParameterizedClass
):
return GemmMain::extract_mnk(a_shape_vec, b_shape_vec, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape_vec, b_inds_shape_vec,
c_inds_shape_vec);
"""
)
return
code
.
ret
(
"std::tuple<int, int, int>"
)
def
__init__
(
self
,
gemm_cu
:
GemmMainUnitTest
,
conv_cu
:
ConvMainUnitTest
):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
ConvTuneResult
)
self
.
add_param_class
(
"gemm"
,
gemm_cu
,
"GemmMain"
)
self
.
add_param_class
(
"conv"
,
conv_cu
,
"ConvMain"
)
@
pccm
.
static_function
def
extract_mnk_vector
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"a_shape, b_shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"shuffle_type"
,
"int"
)
code
.
arg
(
"a_inds_shape, b_inds_shape, c_inds_shape"
,
"std::vector<int64_t>"
)
code
.
raw
(
f
"""
return GemmMain::extract_mnk(a_shape, b_shape, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape, b_inds_shape,
c_inds_shape);
"""
)
return
code
.
ret
(
"std::tuple<int, int, int>"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
cached_get_nvrtc_params
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desp"
,
"tv::gemm::GemmAlgoDesp"
,
pyanno
=
"cumm.tensorview.gemm.GemmAlgoDesp"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"stream_int"
,
"std::uintptr_t"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented in c++, must be overrided in python!!!");
"""
)
return
code
.
ret
(
"tv::gemm::NVRTCParams"
,
pyanno
=
"cumm.tensorview.gemm.NVRTCParams"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
indice_conv
(
self
):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
@
pccm
.
member_function
def
tune_and_cache
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"
out_features_after_mm
"
,
"tv::Tensor"
)
code
.
arg
(
"
features, filters, indice_pairs"
,
"tv::Tensor
"
)
code
.
arg
(
"
indice_pair_num"
,
"tv::Tensor
"
)
code
.
arg
(
"
num_activate_out
"
,
"int"
)
code
.
arg
(
"
inverse"
,
"bool"
,
"false
"
)
code
.
arg
(
"
subm"
,
"bool"
,
"false
"
)
code
.
arg
(
"al
go
"
,
"
int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
code
.
arg
(
"
filter_hwio"
,
"bool"
,
"false
"
)
code
.
arg
(
"
a, b, c
"
,
"tv::Tensor"
)
code
.
arg
(
"
trans_a, trans_b, trans_c"
,
"bool
"
)
code
.
arg
(
"
arch"
,
"std::tuple<int, int>
"
)
code
.
arg
(
"
shuffle_type
"
,
"int"
)
code
.
arg
(
"
a_inds, b_inds, c_inds"
,
"tv::Tensor
"
)
code
.
arg
(
"
hint"
,
"int"
,
f
"
{
AlgoHint
.
NoHint
.
value
}
"
)
code
.
arg
(
"al
pha
"
,
"
float"
,
"1.0
"
)
code
.
arg
(
"
beta"
,
"float"
,
"0.0
"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"num_run"
,
"int"
,
"5"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
"return std::make_tuple(GemmTuneResult(), -1.0f);"
)
return
code
.
ret
(
"std::tuple<GemmTuneResult, float>"
,
pyanno
=
"Tuple[spconv.core_cc.csrc.sparse.convops.GemmTuneResult, float]"
)
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
TV_ASSERT_RT_ERR(num_run > 1, "error");
auto mnk = extract_mnk(a.shape(), b.shape(), trans_a,
trans_b, trans_c,
arch,
shuffle_type,
a_inds.shape(), b_inds.shape(),
c_inds.shape());
auto m = std::get<0>(mnk);
auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk);
auto avail = get_all_available(a, b, c, trans_a,
trans_b, trans_c, arch, shuffle_type);
auto c_ = c.clone_whole_storage();
std::vector<GemmTuneResult> all_profile_res;
std::vector<int> splitk_tests;
std::vector<float> times;
for (auto& desp : avail){{
tv::gemm::GemmParams params;
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.a = a;
params.b = b;
params.c = c_;
params.a_inds = a_inds;
params.b_inds = b_inds;
params.c_inds = c_inds;
params.algo_desp = desp;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
if (desp.split_k_serial() && (hint &
{
AlgoHint
.
BackwardWeight
.
value
}
)){{
splitk_tests = {{
{
', '
.
join
(
map
(
str
,
SPCONV_BWD_SPLITK
))
}
}};
}} else {{
splitk_tests = {{1}};
}}
for (auto spk : splitk_tests){{
float total_time = 0.0;
params.split_k_slices = spk;
for (int j = 0; j < num_run; ++j){{
auto ev_start = tv::CUDAEvent();
auto ev_end = tv::CUDAEvent();
ev_start.record(stream_int);
GemmMain::matmul2(params);
ev_end.record(stream_int);
if (j > 0){{
// skip first run
total_time += tv::CUDAEvent::sync_and_duration(ev_start, ev_end);
}}
}}
total_time /= (num_run - 1);
times.push_back(total_time);
all_profile_res.push_back(GemmTuneResult(desp, arch, spk));
}}
}}
TV_ASSERT_RT_ERR(!all_profile_res.empty(), "can't find suitable algorithm");
auto min_idx = std::min_element(times.begin(), times.end()) - times.begin();
auto min_tune_res = all_profile_res[min_idx];
{{
std::lock_guard<std::mutex> guard(mutex_);
algo_cache_key_t key;
if (hint &
{
AlgoHint
.
BackwardWeight
.
value
}
){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), m, n);
mn_cache_[key] = min_tune_res;
}}
else if (hint &
{
AlgoHint
.
BackwardInput
.
value
}
){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), n, k);
nk_dgrad_cache_[key] = min_tune_res;
}}
else if (hint &
{
AlgoHint
.
Fowrard
.
value
}
){{
key = std::make_tuple(int(a.dtype()), int(b.dtype()), int(c.dtype()), n, k);
nk_forward_cache_[key] = min_tune_res;
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(min_tune_res, times[min_idx]);
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"std::tuple<GemmTuneResult, float>"
,
pyanno
=
"Tuple[spconv.core_cc.csrc.sparse.convops.GemmTuneResult, float]"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
get_tuned_algo
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"a_dtype, b_dtype, c_dtype"
,
"int"
)
code
.
arg
(
"a_shape, b_shape, c_shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"shuffle_type"
,
"int"
)
code
.
arg
(
"a_inds_shape, b_inds_shape, c_inds_shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"hint"
,
"int"
,
f
"
{
AlgoHint
.
NoHint
.
value
}
"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
code
.
raw
(
"return std::make_tuple(GemmTuneResult(), false);"
)
return
code
.
ret
(
"std::tuple<GemmTuneResult, bool>"
)
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(!features.is_cpu(), "this function don't support cpu.")
int out_channel;
if (filter_hwio){{
out_channel = filters.dim(-1);
}}else{{
out_channel = filters.dim(-2);
auto mnk = GemmMain::extract_mnk(a_shape, b_shape, trans_a,
trans_b, trans_c,
shuffle_type,
a_inds_shape, b_inds_shape,
c_inds_shape);
auto m = std::get<0>(mnk);
auto n = std::get<1>(mnk);
auto k = std::get<2>(mnk);
GemmTuneResult res;
bool exists = false;
{{
std::lock_guard<std::mutex> guard(mutex_);
algo_cache_key_t key;
if (hint &
{
AlgoHint
.
BackwardWeight
.
value
}
){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), m, n);
if (mn_cache_.find(key) != mn_cache_.end()){{
res = mn_cache_.at(key);
exists = true;
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
int kv = filters.dim(0);
int kv_center = kv / 2;
tv::Tensor out_features;
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
else if (hint &
{
AlgoHint
.
BackwardInput
.
value
}
){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), n, k);
if (nk_dgrad_cache_.find(key) != nk_dgrad_cache_.end()){{
res = nk_dgrad_cache_.at(key);
exists = true;
}}
}}
if (subm && all_zero){{
return;
else if (hint &
{
AlgoHint
.
Fowrard
.
value
}
){{
key = std::make_tuple(int(a_dtype), int(b_dtype), int(c_dtype), n, k);
if (nk_forward_cache_.find(key) != nk_forward_cache_.end()){{
res = nk_forward_cache_.at(key);
exists = true;
}}
bool inited = subm;
auto a = features;
auto c = out_features
;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
}}
else{{
TV_THROW_RT_ERR("not implemented")
;
}}
}}
return std::make_tuple(res, exists);
"""
)
return
code
.
ret
(
"std::tuple<GemmTuneResult, bool>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
run_with_tuned_result
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"profile_res"
,
"GemmTuneResult"
)
code
.
arg
(
"a, b, c"
,
"tv::Tensor"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
)
code
.
arg
(
"shuffle_type"
,
"int"
)
code
.
arg
(
"a_inds, b_inds, c_inds"
,
"tv::Tensor"
)
code
.
arg
(
"hint"
,
"int"
,
f
"
{
AlgoHint
.
NoHint
.
value
}
"
)
code
.
arg
(
"alpha"
,
"float"
,
"1.0"
)
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"workspace"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)"
)
code
.
arg
(
"force_nvrtc"
,
f
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
code
.
raw
(
f
"""
auto& desp = profile_res.algo_desp;
int split_k_slices = 1;
if (profile_res.splitk > 1){{
split_k_slices = profile_res.splitk;
}}
tv::gemm::GemmParams params;
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (desp.is_nvrtc && (desp_is_static || force_nvrtc)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
}}
params.a = a;
params.b = b;
params.c = c;
params.a_inds = a_inds;
params.b_inds = b_inds;
params.c_inds = c_inds;
params.algo_desp = desp;
params.split_k_slices = split_k_slices;
params.stream = stream_int;
params.alpha = alpha;
params.beta = beta;
params.workspace = workspace;
GemmMain::matmul2(params);
"""
)
return
code
class
ConvTunerSimple
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
conv_cu
:
Optional
[
ConvMainUnitTest
]
=
None
):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
ConvTuneResult
,
TensorView
,
GemmBasicHost
,
CompileInfo
)
if
conv_cu
is
not
None
:
self
.
add_param_class
(
"conv"
,
conv_cu
,
"ConvMain"
)
if
not
CUMM_CPU_ONLY_BUILD
:
assert
conv_cu
is
not
None
self
.
add_include
(
"tensorview/profile/cuda_profiler.h"
)
self
.
add_include
(
"tensorview/utility/tuplehash.h"
)
self
.
add_include
(
"mutex"
)
self
.
add_typedef
(
"static_key_t"
,
(
"std::tuple<int, int, int, int, int, "
"int, int, int, int, std::string, int>"
))
self
.
add_typedef
(
"algo_cache_key_t"
,
"std::tuple<int, int, int, int, "
"int, int, int, int>"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::ConvAlgoDesp>"
)
self
.
add_member
(
"static_key_to_desps_"
,
"std::unordered_map<static_key_t, std::vector<tv::gemm::ConvAlgoDesp>>"
)
self
.
add_member
(
"prebuilt_names_"
,
"std::unordered_set<std::string>"
)
self
.
add_member
(
"mutex_"
,
"std::mutex"
)
self
.
add_member
(
"kc_forward_cache_, kc_dgrad_cache_, kc_wgrad_cache_"
,
"std::unordered_map<algo_cache_key_t, ConvTuneResult>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desps"
,
"std::vector<tv::gemm::ConvAlgoDesp>"
)
code
.
ctor_init
(
"desps_"
,
"desps"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
code
.
raw
(
f
"""
for (auto& d : desps){{
static_key_t static_key = std::make_tuple(
int(d.layout_i), int(d.layout_w), int(d.layout_o),
d.interleave_i, d.interleave_w, d.interleave_o, d.dtype_input(),
d.dtype_weight(), d.dtype_output(), d.algo, int(d.op_type));
auto& vec = static_key_to_desps_[static_key];
vec.push_back(d);
}}
for (auto desp : ConvMain::get_all_conv_algo_desp()){{
prebuilt_names_.insert(desp.__repr__());
}}
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_available_algo_str_from_arch
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
raw
(
f
"""
std::vector<std::string> res;
"""
)
for
i
in
range
(
len
(
_GEMM_MIN_ARCH_TO_ALGO
)
-
1
,
-
1
,
-
1
):
arch_cur
,
algos
=
_GEMM_MIN_ARCH_TO_ALGO
[
i
]
code
.
raw
(
f
"""
auto arch_cur_
{
i
}
= std::make_tuple(int(
{
arch_cur
[
0
]
}
), int(
{
arch_cur
[
1
]
}
));
"""
)
with
code
.
if_
(
f
"arch >= arch_cur_
{
i
}
"
):
for
algo
in
algos
:
code
.
raw
(
f
"""
res.push_back(
{
pccm
.
literal
(
algo
)
}
);
"""
)
code
.
raw
(
f
"return res;"
)
return
code
.
ret
(
"std::vector<std::string>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
get_all_available
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"inp, weight, out"
,
"tv::Tensor"
)
code
.
arg
(
"layout_i, layout_w, layout_o"
,
"int"
)
code
.
arg
(
"interleave_i, interleave_w, interleave_o"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"op_type"
,
"int"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
)
code
.
arg
(
"fp32_accum"
,
"bool"
)
code
.
raw
(
f
"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
auto avail_algos = get_available_algo_str_from_arch(arch);
bool is_fp16 = (inp.dtype() == tv::float16 &&
weight.dtype() == tv::float16 && out.dtype() == tv::float16);
bool use_f32_as_accum = false;
int kv = 1;
for (int i = 0; i < weight.ndim() - 2; ++i){{
kv *= weight.dim(i + 1);
}}
if (is_fp16){{
if (auto_fp32_accum){{
if (op_type_cpp == tv::gemm::ConvOpType::kForward)
use_f32_as_accum = weight.dim(-1) * kv > 128 * 27;
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput)
use_f32_as_accum = weight.dim(0) * kv > 128 * 27;
}}else{{
use_f32_as_accum = fp32_accum;
}}
}}
use_f32_as_accum = false;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled(arch);
for (auto algo : avail_algos){{
static_key_t static_key = std::make_tuple(
layout_i, layout_w, layout_o,
interleave_i, interleave_w, interleave_o, inp.dtype(),
weight.dtype(), out.dtype(), algo, op_type);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
continue;
}}
auto& desps = static_key_to_desps_.at(static_key);
for (auto& desp : desps){{
if (arch >= std::make_tuple(7, 5) && desp.algo ==
{
pccm
.
literal
(
GemmAlgo
.
Volta
.
value
)
}
){{
continue;
}}
if (arch >= std::make_tuple(7, 0) && is_fp16){{
// skip simt fp16 kernels if we have tensor core
if (desp.algo ==
{
pccm
.
literal
(
GemmAlgo
.
Simt
.
value
)
}
){{
continue;
}}
if (use_f32_as_accum){{
if (desp.dacc == tv::float16){{
continue;
}}
}}
}}
int ldi = inp.dim(-1);
int ldw = weight.dim(-1);
int ldo = out.dim(-1);
bool mask_width_valid = true;
if (desp.op_type == tv::gemm::ConvOpType::kBackwardWeight){{
TV_ASSERT_RT_ERR(mask_width > 0, "eroro");
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{
auto desp2 = desp;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
finally_algos.push_back(desp);
}}
}}
}}
}}
return finally_algos;
"""
)
return
code
.
ret
(
"std::vector<tv::gemm::ConvAlgoDesp>"
,
pyanno
=
"List[cumm.tensorview.gemm.ConvAlgoDesp]"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
)
def
cached_get_nvrtc_params
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desp"
,
"tv::gemm::ConvAlgoDesp"
,
pyanno
=
"cumm.tensorview.gemm.ConvAlgoDesp"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"stream_int"
,
"std::uintptr_t"
)
code
.
raw
(
f
"""
TV_THROW_RT_ERR("not implemented in c++, must be overrided in python!!!");
"""
)
return
code
.
ret
(
"tv::gemm::NVRTCParams"
,
pyanno
=
"cumm.tensorview.gemm.NVRTCParams"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
tune_and_cache
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"op_type"
,
"int"
)
code
.
arg
(
"inp, weight, output"
,
"tv::Tensor"
)
code
.
arg
(
"layout_i, layout_w, layout_o"
,
"int"
)
code
.
arg
(
"interleave_i, interleave_w, interleave_o"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask, mask_argsort, indices"
,
"tv::Tensor"
)
code
.
arg
(
"reverse_mask"
,
"bool"
)
code
.
arg
(
"mask_filter"
,
"uint32_t"
,
"0xffffffff"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
code
.
arg
(
"mask_output"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"alpha"
,
"float"
,
"1.0"
)
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
code
.
arg
(
"num_run"
,
"int"
,
"5"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, float>"
,
pyanno
=
"Tuple[spconv.core_cc.csrc.sparse.convops.ConvTuneResult, float]"
)
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(num_run > 1, "error");
auto avail = get_all_available(inp, weight, output, layout_i, layout_w,
layout_o, interleave_i, interleave_w, interleave_o,
arch, op_type, mask_width,
auto_fp32_accum, fp32_accum);
inp = inp.clone();
weight = weight.clone();
output = output.clone();
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
std::vector<ConvTuneResult> all_profile_res;
std::vector<int> splitk_tests;
std::vector<float> times;
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
for (auto& desp : avail){{
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, tv::CUDAKernelTimer(false));
if (desp.is_nvrtc && prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
params.input = inp;
params.weight = weight.view(channel_k, -1, channel_c);
params.output = output;
params.mask_width = mask_width;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
params.mask = mask;
params.mask_output = mask_output;
// if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
// TV_ASSERT_RT_ERR(!mask_output.empty(), "error");
// }}
if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
params.reverse_mask = reverse_mask;
}}
params.mask_filter = mask_filter;
if (desp.split_k_serial() && (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight)){{
splitk_tests = {{
{
', '
.
join
(
map
(
str
,
SPCONV_BWD_SPLITK
))
}
}};
}} else {{
splitk_tests = {{1}};
}}
for (auto spk : splitk_tests){{
float total_time = 0.0;
params.split_k_slices = spk;
for (int j = 0; j < num_run; ++j){{
auto ev_start = tv::CUDAEvent();
auto ev_end = tv::CUDAEvent();
ev_start.record(stream_int);
ConvMain::implicit_gemm2(params);
ev_end.record(stream_int);
if (j > 0){{
// skip first run
total_time += tv::CUDAEvent::sync_and_duration(ev_start, ev_end);
}}
}}
total_time /= (num_run - 1);
times.push_back(total_time);
all_profile_res.push_back(ConvTuneResult(desp, arch, spk));
}}
}}
TV_ASSERT_RT_ERR(!all_profile_res.empty(), "can't find suitable algorithm for", op_type);
auto min_idx = std::min_element(times.begin(), times.end()) - times.begin();
auto min_tune_res = all_profile_res[min_idx];
if (op_type_cpp != tv::gemm::ConvOpType::kBackwardWeight){{
mask_width = -1;
}}
algo_cache_key_t key;
key = std::make_tuple(int(inp.dtype()), int(weight.dtype()),
int(output.dtype()), channel_k, channel_c, std::get<0>(arch), std::get<1>(arch), mask_width);
{{
std::lock_guard<std::mutex> guard(mutex_);
if (op_type_cpp == tv::gemm::ConvOpType::kForward){{
kc_forward_cache_[key] = min_tune_res;
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
kc_dgrad_cache_[key] = min_tune_res;
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
kc_wgrad_cache_[key] = min_tune_res;
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(min_tune_res, times[min_idx]);
"""
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, float>"
,
pyanno
=
"Tuple[spconv.core_cc.csrc.sparse.convops.ConvTuneResult, float]"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
get_tuned_algo
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"op_type"
,
"int"
)
code
.
arg
(
"i_dtype, w_dtype, o_dtype"
,
"int"
)
code
.
arg
(
"k, c"
,
"int"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, bool>"
)
code
.
raw
(
f
"""
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
if (op_type_cpp != tv::gemm::ConvOpType::kBackwardWeight){{
mask_width = -1;
}}
algo_cache_key_t key;
key = std::make_tuple(i_dtype, w_dtype, o_dtype, k, c,
std::get<0>(arch), std::get<1>(arch), mask_width);
ConvTuneResult res;
bool exists = false;
{{
std::lock_guard<std::mutex> guard(mutex_);
if (op_type_cpp == tv::gemm::ConvOpType::kForward){{
if (kc_forward_cache_.find(key) != kc_forward_cache_.end()){{
res = kc_forward_cache_.at(key);
exists = true;
}}
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardInput){{
if (kc_dgrad_cache_.find(key) != kc_dgrad_cache_.end()){{
res = kc_dgrad_cache_.at(key);
exists = true;
}}
}}
else if (op_type_cpp == tv::gemm::ConvOpType::kBackwardWeight){{
if (kc_wgrad_cache_.find(key) != kc_wgrad_cache_.end()){{
res = kc_wgrad_cache_.at(key);
exists = true;
}}
}}
else{{
TV_THROW_RT_ERR("not implemented");
}}
}}
return std::make_tuple(res, exists);
"""
)
return
code
.
ret
(
"std::tuple<ConvTuneResult, bool>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
run_with_tuned_result
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"profile_res"
,
"ConvTuneResult"
)
code
.
arg
(
"op_type"
,
"int"
)
code
.
arg
(
"inp, weight, output"
,
"tv::Tensor"
)
code
.
arg
(
"mask, mask_argsort, mask_output, indices"
,
"tv::Tensor"
)
code
.
arg
(
"reverse_mask"
,
"bool"
)
code
.
arg
(
"mask_filter"
,
"uint32_t"
,
"0xffffffff"
)
code
.
arg
(
"mask_width"
,
"int"
,
"-1"
)
code
.
arg
(
"alpha"
,
"float"
,
"1.0"
)
code
.
arg
(
"beta"
,
"float"
,
"0.0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"workspace"
,
"tv::Tensor"
,
"tv::Tensor()"
,
"cumm.tensorview.Tensor = Tensor()"
)
code
.
arg
(
"verbose"
,
f
"bool"
,
"false"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(false)"
)
code
.
arg
(
"force_nvrtc"
,
f
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
code
.
raw
(
f
"""
auto desp = profile_res.algo_desp;
if (force_nvrtc){{
desp.is_nvrtc = true;
}}
int split_k_slices = 1;
if (profile_res.splitk > 1){{
split_k_slices = profile_res.splitk;
}}
int channel_k = output.dim(1);
int channel_c = inp.dim(1);
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
auto arch = profile_res.arch;
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, timer);
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (desp.is_nvrtc && (desp_is_static || force_nvrtc)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
params.conv_algo_desp = desp;
params.input = inp;
params.weight = weight.view(channel_k, -1, channel_c);
params.output = output;
params.verbose = verbose;
params.split_k_slices = split_k_slices;
params.alpha = alpha;
params.beta = beta;
params.stream = stream_int;
params.mask_argsort = mask_argsort;
params.indices = indices;
params.mask = mask;
params.mask_filter = mask_filter;
params.mask_width = mask_width;
params.mask_output = mask_output;
params.reverse_mask = reverse_mask;
if (timer.enable()){{
params.timer = timer;
}}
params.workspace = workspace;
ConvMain::implicit_gemm2(params);
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
query_workspace_size
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desp"
,
"tv::gemm::ConvAlgoDesp"
)
code
.
arg
(
"splitk"
,
"int"
)
code
.
arg
(
"op_type, N, C, K, kv"
,
"int"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"int"
)
code
.
raw
(
f
'''
auto mnk = ConvMain::extract_mnk(op_type, N, C, K, kv, -1, -1, true);
return desp.query_conv_workspace_size(
std::get<0>(mnk), std::get<1>(mnk), std::get<2>(mnk),
splitk, kv);
'''
)
return
code
.
ret
(
"int"
)
class
ConvGemmOps
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
gemm_tuner
:
GemmTunerSimple
,
conv_tuner
:
ConvTunerSimple
):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
ConvTuneResult
,
ExternalSpconvMatmul
,
)
self
.
add_param_class
(
"gemm"
,
gemm_tuner
,
"GemmTuner"
)
self
.
add_param_class
(
"conv"
,
conv_tuner
,
"ConvTuner"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_compute_capability
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"index"
,
"int"
,
"-1"
)
code
.
raw
(
f
"""
if (index == -1){{
checkCudaErrors(cudaGetDevice(&index));
}}
#ifdef TV_CUDA
cudaDeviceProp prop;
checkCudaErrors(cudaGetDeviceProperties(&prop, index));
return std::make_tuple(prop.major, prop.minor);
#else
return std::make_tuple(-1, -1);
#endif
"""
)
return
code
.
ret
(
"std::tuple<int, int>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
indice_conv
(
self
):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
code
=
pccm
.
code
()
code
.
add_dependency
(
GatherCPU
)
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"ext_mm"
,
"ExternalSpconvMatmul&"
)
code
.
arg
(
"gemm_tuner"
,
"GemmTuner&"
)
code
.
arg
(
"all_w_is_krsc, filter_hwio"
,
"bool"
)
code
.
arg
(
"features, filters, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
int kv_dim, out_channel, kv;
std::vector<int64_t> filter_shape_per_kv;
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
if (filter_hwio){{
out_channel = filters.dim(-1);
filter_shape_per_kv = {{filters.dim(-2), out_channel}};
}}else{{
out_channel = filters.dim(-2);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
kv = filters.dim(0);
}}else{{
kv_dim = 1;
out_channel = filters.dim(0);
filters = filters.view(out_channel, -1, filters.dim(-1));
is_KC_not_CK = true;
kv = filters.dim(1);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
int kv_center = kv / 2;
tv::Tensor out_features;
if (subm){{
out_features = ext_mm.indice_conv_init_gemm(
{
pccm
.
literal
(
AllocKeys
.
Features
)
}
,
{
pccm
.
literal
(
AllocKeys
.
Filters
)
}
, all_w_is_krsc,
is_KC_not_CK, kv_center, out_channel);
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
indice_pair_num_cpu_ptr[i] = std::min(indice_pair_num_cpu_ptr[i], int(indice_pairs.dim(2)));
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
if (features.is_cpu()){{
TV_ASSERT_RT_ERR(filters.is_cpu() && indice_pairs.is_cpu(), "error");
auto inp_buffer = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
InpBuffer
)
}
,
{{maxnhot, features.dim(1)}}, features.dtype(), -1);
auto out_buffer = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutBuffer
)
}
,
{{maxnhot, out_features.dim(1)}}, out_features.dtype(), -1);
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
GatherCPU::gather(inp_buffer, a, inp_indices);
ext_mm.indice_conv_cpu_gemm(
{
pccm
.
literal
(
AllocKeys
.
InpBuffer
)
}
,
{
pccm
.
literal
(
AllocKeys
.
OutBuffer
)
}
,
{
pccm
.
literal
(
AllocKeys
.
Filters
)
}
, all_w_is_krsc,
is_KC_not_CK, nhot, i);
GatherCPU::scatter_add(c, out_buffer, out_indices);
}}
return;
}}
"""
)
if
CUMM_CPU_ONLY_BUILD
:
return
code
code
.
raw
(
f
"""
int profile_idx = kv_center;
if (subm)
profile_idx = kv_center - 1;
int nhot_profile = indice_pair_num_cpu_ptr[profile_idx];
if (nhot_profile == 0){{
profile_idx = 0;
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (nhot > nhot_profile){{
nhot_profile = nhot;
profile_idx = i;
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
auto a_shape = a.shape();
auto c_shape = c.shape();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
auto tuned_res_exist = gemm_tuner.get_tuned_algo(
int(a.dtype()),
int(filters.dtype()),
int(c.dtype()),
std::vector<int64_t>(a_shape.begin(), a_shape.end()),
filter_shape_per_kv,
std::vector<int64_t>(c_shape.begin(), c_shape.end()),
false,
is_KC_not_CK,
false,
arch,
sac_shuffle_type,
{{nhot_profile}},
{{}},
{{nhot_profile}},
{
AlgoHint
.
Fowrard
.
value
}
);
auto tune_res = std::get<0>(tuned_res_exist);
auto exists = std::get<1>(tuned_res_exist);
if (!exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto filter = filters.select(kv_dim, profile_idx);
auto tune_res_time = gemm_tuner.tune_and_cache(
a,
filter,
c,
false,
is_KC_not_CK,
false,
arch,
sac_shuffle_type,
inp_indices,
tv::Tensor(),
out_indices,
{
AlgoHint
.
Fowrard
.
value
}
,
1.0,
0.0,
stream_int);
tune_res = std::get<0>(tune_res_time);
}}
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
auto b = filters.select(kv_dim, i);
float beta = inited ? 1.0 : 0.0;
gemm_tuner.run_with_tuned_result(
tune_res,
a,
b,
c,
false,
is_KC_not_CK,
false,
arch,
stream_int,
sac_shuffle_type,
inp_indices,
tv::Tensor(),
out_indices,
{
AlgoHint
.
Fowrard
.
value
}
,
1.0,
beta);
inited = true;
}}
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
indice_conv_backward
(
self
):
code
=
pccm
.
code
()
code
.
add_dependency
(
GatherCPU
)
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"ext_mm"
,
"ExternalSpconvMatmul&"
)
code
.
arg
(
"gemm_tuner"
,
"GemmTuner&"
)
code
.
arg
(
"all_w_is_krsc, filter_hwio"
,
"bool"
)
code
.
arg
(
"features, filters, out_bp, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
raw
(
f
"""
int kv_dim, out_channel, kv;
std::vector<int64_t> filter_shape_per_kv;
auto prev_filter_shape_vec = filters.shape_vector();
bool is_KC_not_CK;
if (!all_w_is_krsc){{
kv_dim = 0;
is_KC_not_CK = !filter_hwio;
if (filter_hwio){{
out_channel = filters.dim(-1);
filter_shape_per_kv = {{filters.dim(-2), out_channel}};
}}else{{
out_channel = filters.dim(-2);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
kv = filters.dim(0);
}}else{{
kv_dim = 1;
out_channel = filters.dim(0);
filters = filters.view(out_channel, -1, filters.dim(-1));
is_KC_not_CK = true;
kv = filters.dim(1);
filter_shape_per_kv = {{out_channel, filters.dim(-1)}};
}}
int kv_center = kv / 2;
tv::Tensor din;
auto dfilters = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
DFilters
)
}
,
prev_filter_shape_vec, features.dtype(), features.device());
dfilters = dfilters.view(filters.shape());
if (subm){{
din = ext_mm.indice_conv_bwd_init_gemm(
{
pccm
.
literal
(
AllocKeys
.
Features
)
}
,
{
pccm
.
literal
(
AllocKeys
.
Filters
)
}
,
{
pccm
.
literal
(
AllocKeys
.
OutBp
)
}
,
{
pccm
.
literal
(
AllocKeys
.
DFilters
)
}
,
all_w_is_krsc,
is_KC_not_CK, kv_center);
}}else{{
din = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
DIn
)
}
,
features.shape_vector(), features.dtype(), features.device());
}}
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
indice_pair_num_cpu_ptr[i] = std::min(indice_pair_num_cpu_ptr[i], int(indice_pairs.dim(2)));
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
if (features.is_cpu()){{
TV_ASSERT_RT_ERR(filters.is_cpu() && indice_pairs.is_cpu(), "error");
auto inp_buffer = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
InpBuffer
)
}
,
{{maxnhot, features.dim(1)}}, features.dtype(), -1);
auto out_buffer = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutBuffer
)
}
,
{{maxnhot, out_bp.dim(1)}}, out_bp.dtype(), -1);
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
GatherCPU::gather(inp_buffer, features, inp_indices);
GatherCPU::gather(out_buffer, out_bp, out_indices);
ext_mm.indice_conv_bwd_cpu_gemm(
{
pccm
.
literal
(
AllocKeys
.
InpBuffer
)
}
,
{
pccm
.
literal
(
AllocKeys
.
OutBuffer
)
}
,
{
pccm
.
literal
(
AllocKeys
.
Filters
)
}
,
{
pccm
.
literal
(
AllocKeys
.
DFilters
)
}
, all_w_is_krsc,
is_KC_not_CK, nhot, i);
GatherCPU::scatter_add(din, inp_buffer, inp_indices);
}}
return;
}}
"""
)
if
CUMM_CPU_ONLY_BUILD
:
return
code
code
.
raw
(
f
"""
int profile_idx = kv_center;
if (subm)
profile_idx = kv_center - 1;
int nhot_profile = indice_pair_num_cpu_ptr[profile_idx];
if (nhot_profile == 0){{
profile_idx = 0;
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (nhot > nhot_profile){{
nhot_profile = nhot;
profile_idx = i;
}}
}}
}}
TV_ASSERT_RT_ERR(nhot_profile > 0, "this shouldn't happen");
auto arch = get_compute_capability();
int sac_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAC);
int sab_shuffle_type = static_cast<int>(tv::gemm::ShuffleStrideType::kShuffleAB);
auto dgrad_tuned_res_exist = gemm_tuner.get_tuned_algo(
int(out_bp.dtype()),
int(filters.dtype()),
int(din.dtype()),
out_bp.shape_vector(),
filter_shape_per_kv,
din.shape_vector(),
false,
!is_KC_not_CK,
false,
arch,
sac_shuffle_type,
{{nhot_profile}},
{{}},
{{nhot_profile}},
{
AlgoHint
.
BackwardInput
.
value
}
);
auto tuned_res_dgrad = std::get<0>(dgrad_tuned_res_exist);
auto dgrad_exists = std::get<1>(dgrad_tuned_res_exist);
if (!dgrad_exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto filter = filters.select(kv_dim, profile_idx);
auto tune_res_time = gemm_tuner.tune_and_cache(
out_bp,
filter,
din,
false,
!is_KC_not_CK,
false,
arch,
sac_shuffle_type,
out_indices,
tv::Tensor(),
inp_indices,
{
AlgoHint
.
BackwardInput
.
value
}
,
1.0,
0.0,
stream_int);
tuned_res_dgrad = std::get<0>(tune_res_time);
}}
tv::Tensor a_wgrad, b_wgrad;
if (is_KC_not_CK){{
a_wgrad = out_bp;
b_wgrad = features;
}}
else{{
a_wgrad = features;
b_wgrad = out_bp;
}}
auto wgrad_tuned_res_exist = gemm_tuner.get_tuned_algo(
int(a_wgrad.dtype()),
int(b_wgrad.dtype()),
int(filters.dtype()),
a_wgrad.shape_vector(),
b_wgrad.shape_vector(),
filter_shape_per_kv,
true,
false,
false,
arch,
sab_shuffle_type,
{{nhot_profile}},
{{nhot_profile}},
{{}},
{
AlgoHint
.
BackwardWeight
.
value
}
);
auto tuned_res_wgrad = std::get<0>(wgrad_tuned_res_exist);
auto wgrad_exists = std::get<1>(wgrad_tuned_res_exist);
if (!wgrad_exists){{
auto inp_indices = pair_in[profile_idx].slice_first_axis(0, nhot_profile);
auto out_indices = pair_out[profile_idx].slice_first_axis(0, nhot_profile);
auto dfilter = dfilters.select(kv_dim, profile_idx);
tv::Tensor a_inds_wgrad, b_inds_wgrad;
if (is_KC_not_CK){{
a_inds_wgrad = out_indices;
b_inds_wgrad = inp_indices;
}}else{{
a_inds_wgrad = inp_indices;
b_inds_wgrad = out_indices;
}}
auto tune_res_time = gemm_tuner.tune_and_cache(
a_wgrad,
b_wgrad,
dfilter,
true,
false,
false,
arch,
sab_shuffle_type,
a_inds_wgrad,
b_inds_wgrad,
tv::Tensor(),
{
AlgoHint
.
BackwardWeight
.
value
}
,
1.0,
0.0,
stream_int);
tuned_res_wgrad = std::get<0>(tune_res_time);
}}
std::vector<int64_t> a_shape{{maxnhot, out_bp.dim(1)}};
std::vector<int64_t> b_shape{{maxnhot, features.dim(1)}};
if (!is_KC_not_CK){{
std::swap(a_shape, b_shape);
}}
auto mnk = GemmTuner::extract_mnk_vector(a_shape, b_shape,
tuned_res_wgrad.algo_desp.trans_a(),
tuned_res_wgrad.algo_desp.trans_b(),
tuned_res_wgrad.algo_desp.trans_c(),
sab_shuffle_type,
{{maxnhot}}, {{maxnhot}}, {{}});
auto ws_size = tuned_res_wgrad.algo_desp.query_workspace_size(
std::get<0>(mnk), std::get<1>(mnk), std::get<2>(mnk), tuned_res_wgrad.splitk);
ExternalAllocator::guard_t workspace_guard;
tv::Tensor workspace;
if (ws_size > 0){{
workspace_guard = allocator.empty_guard({{int64_t(ws_size)}}, tv::uint8, 0);
workspace = workspace_guard->tensor;
}}
for (int i = 0; i < kv; ++i){{
int nhot = indice_pair_num_cpu_ptr[i];
if (subm && i == kv_center){{
continue;
}}
if (subm && i > kv_center){{
nhot = indice_pair_num_cpu_ptr[kv - i - 1];
}}
if (nhot <= 0){{
continue;
}}
auto inp_indices = pair_in[i].slice_first_axis(0, nhot);
auto out_indices = pair_out[i].slice_first_axis(0, nhot);
auto filter_i = filters.select(kv_dim, i);
float beta = inited ? 1.0 : 0.0;
gemm_tuner.run_with_tuned_result(
tuned_res_dgrad,
out_bp,
filter_i,
din,
false,
!is_KC_not_CK,
false,
arch,
stream_int,
sac_shuffle_type,
out_indices,
tv::Tensor(),
inp_indices,
{
AlgoHint
.
BackwardInput
.
value
}
,
1.0,
beta);
tv::Tensor a = out_bp;
tv::Tensor b = features;
tv::Tensor a_inds = out_indices;
tv::Tensor b_inds = inp_indices;
if (!is_KC_not_CK){{
std::swap(a, b);
std::swap(a_inds, b_inds);
}}
gemm_tuner.run_with_tuned_result(
tuned_res_wgrad,
a,
b,
dfilters.select(kv_dim, i),
true,
false,
false,
arch,
stream_int,
sab_shuffle_type,
a_inds,
b_inds,
tv::Tensor(),
{
AlgoHint
.
BackwardWeight
.
value
}
,
1.0,
beta);
inited = true;
}}
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
implicit_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"conv_tuner"
,
"ConvTuner&"
)
code
.
arg
(
"features, filters, pair_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"pair_mask_fwd_splits, mask_argsort_fwd_splits"
,
"std::vector<tv::Tensor>"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"is_train, is_subm"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
.
ret
(
"int"
)
code
.
raw
(
f
"""
uint32_t* mask_ptr = masks.data_ptr<uint32_t>();
int num_mask = masks.dim(0);
int out_channel = filters.dim(0);
int in_channel = filters.dim(-1);
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
tv::Tensor out_features;
if (is_subm){{
out_features = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}else{{
out_features = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
OutFeatures
)
}
,
{{num_activate_out, out_channel}}, features.dtype(), features.device());
}}
auto arch = get_compute_capability();
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto tuned_res_exist = conv_tuner.get_tuned_algo(
kForwardInt,
int(features.dtype()),
int(filters.dtype()),
int(out_features.dtype()),
out_channel, in_channel, arch);
auto tune_res = std::get<0>(tuned_res_exist);
auto exists = std::get<1>(tuned_res_exist);
if (!exists){{
auto tune_res_time = conv_tuner.tune_and_cache(
kForwardInt,
features, filters, out_features,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
pair_mask_fwd_splits[0].type_view(tv::uint32),
mask_argsort_fwd_splits[0],
pair_fwd,
false, // reverse_mask
mask_ptr[0], // mask_filter
-1,
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
tune_res = std::get<0>(tune_res_time);
}}
int mask_width = tune_res.algo_desp.tile_shape[0];
tv::Tensor mask_output_fwd;
std::vector<tv::Tensor> mask_output_fwd_splits;
if (is_train){{
mask_output_fwd = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
MaskOutputFwd
)
}
,
{{num_split, tv::div_up(num_activate_out, mask_width)}},
tv::uint32, features.device());
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(mask_output_fwd[i]);
}}
}}else{{
for (int i = 0; i < num_split; ++i){{
mask_output_fwd_splits.push_back(tv::Tensor());
}}
}}
for (int j = 0; j < num_split; ++j){{
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
tune_res,
kForwardInt,
features,
filters,
out_features,
pair_mask_fwd_splits[j].type_view(tv::uint32),
mask_argsort_fwd_splits[j],
mask_output_fwd_splits[j],
pair_fwd,
false, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0, beta,
stream_int,
tv::Tensor(), // workspace
false, // verbose
timer);
}}
return mask_width;
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
implicit_gemm_backward
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"conv_tuner"
,
"ConvTuner&"
)
code
.
arg
(
"features, filters, out_bp, pair_fwd, pair_bwd"
,
"tv::Tensor"
)
code
.
arg
(
"pair_mask_fwd_splits, pair_mask_bwd_splits"
,
"std::vector<tv::Tensor>"
)
code
.
arg
(
"mask_argsort_fwd_splits, mask_argsort_bwd_splits"
,
"std::vector<tv::Tensor>"
)
code
.
arg
(
"mask_output_fwd"
,
"tv::Tensor"
)
code
.
arg
(
"masks"
,
"tv::Tensor"
)
code
.
arg
(
"mask_width"
,
"int"
)
code
.
arg
(
"is_subm"
,
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
code
.
arg
(
"timer"
,
"tv::CUDAKernelTimer"
,
"tv::CUDAKernelTimer(false)"
,
"cumm.tensorview.CUDAKernelTimer = CUDAKernelTimer(False)"
)
code
.
arg
(
"auto_fp32_accum"
,
"bool"
,
"true"
)
code
.
arg
(
"fp32_accum"
,
"bool"
,
"false"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"TV_THROW_RT_ERR(
\"
not implemented for cpu!!!
\"
)"
)
return
code
code
.
raw
(
f
"""
auto filters_shape = filters.shape();
auto filters_shape_vec = filters.shape_vector();
uint32_t* mask_ptr = masks.data_ptr<uint32_t>();
int num_mask = masks.dim(0);
int out_channel = filters.dim(0);
int in_channel = filters.dim(-1);
int num_split = pair_mask_fwd_splits.size();
TV_ASSERT_RT_ERR(num_mask == num_split, "error");
filters = filters.view(out_channel, -1, in_channel);
int kv = filters.dim(1);
tv::Tensor din;
if (is_subm){{
din = allocator.empty(
{
pccm
.
literal
(
AllocKeys
.
DIn
)
}
,
features.shape_vector(), features.dtype(), features.device());
}}else{{
din = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
DIn
)
}
,
features.shape_vector(), features.dtype(), features.device());
}}
tv::Tensor dfilters = allocator.zeros(
{
pccm
.
literal
(
AllocKeys
.
DFilters
)
}
,
filters_shape_vec, filters.dtype(), filters.device());
dfilters = dfilters.view(out_channel, -1, in_channel);
constexpr auto kForwardInt = static_cast<int>(tv::gemm::ConvOpType::kForward);
constexpr auto kBackwardInputInt = static_cast<int>(tv::gemm::ConvOpType::kBackwardInput);
constexpr auto kBackwardWeightInt = static_cast<int>(tv::gemm::ConvOpType::kBackwardWeight);
constexpr auto kChannelLastInt = static_cast<int>(tv::gemm::ConvLayoutType::kChannelLast);
auto arch = get_compute_capability();
auto dgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardInputInt,
int(din.dtype()),
int(filters.dtype()),
int(out_bp.dtype()),
out_channel, in_channel, arch);
auto wgrad_tuned_res_exist = conv_tuner.get_tuned_algo(
kBackwardWeightInt,
int(features.dtype()),
int(dfilters.dtype()),
int(out_bp.dtype()),
out_channel, in_channel, arch, mask_width);
auto dgrad_tune_res = std::get<0>(dgrad_tuned_res_exist);
auto dgrad_exists = std::get<1>(dgrad_tuned_res_exist);
auto wgrad_tune_res = std::get<0>(wgrad_tuned_res_exist);
auto wgrad_exists = std::get<1>(wgrad_tuned_res_exist);
if (!dgrad_exists){{
tv::Tensor mask, mask_argsort;
if (is_subm){{
mask = pair_mask_fwd_splits[0].type_view(tv::uint32);
mask_argsort = mask_argsort_fwd_splits[0];
}}else{{
mask = pair_mask_bwd_splits[0].type_view(tv::uint32);
mask_argsort = mask_argsort_bwd_splits[0];
}}
auto tune_res_time = conv_tuner.tune_and_cache(
kBackwardInputInt,
din, filters, out_bp,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
mask,
mask_argsort,
pair_bwd,
is_subm, // reverse_mask
mask_ptr[0], // mask_filter
-1, // mask width
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
dgrad_tune_res = std::get<0>(tune_res_time);
}}
if (!wgrad_exists){{
auto tune_res_time = conv_tuner.tune_and_cache(
kBackwardWeightInt,
features, dfilters, out_bp,
kChannelLastInt,
kChannelLastInt,
kChannelLastInt,
1, 1, 1,
arch,
mask_output_fwd[0].type_view(tv::uint32),
mask_argsort_fwd_splits[0],
pair_fwd,
false, // reverse_mask
mask_ptr[0], // mask_filter
mask_width,
tv::Tensor(), // mask_output
1.0, 0.0,
stream_int,
auto_fp32_accum,
fp32_accum);
wgrad_tune_res = std::get<0>(tune_res_time);
}}
int ws_size = conv_tuner.query_workspace_size(wgrad_tune_res.algo_desp,
wgrad_tune_res.splitk,
kBackwardWeightInt,
pair_fwd.dim(1), in_channel,
out_channel, kv);
ExternalAllocator::guard_t workspace_guard;
tv::Tensor workspace;
if (ws_size > 0){{
workspace_guard = allocator.empty_guard({{int64_t(ws_size)}}, tv::uint8, 0);
workspace = workspace_guard->tensor;
}}
for (int j = 0; j < num_split; ++j){{
tv::Tensor mask, mask_argsort;
if (is_subm){{
mask = pair_mask_fwd_splits[j].type_view(tv::uint32);
mask_argsort = mask_argsort_fwd_splits[j];
}}else{{
mask = pair_mask_bwd_splits[j].type_view(tv::uint32);
mask_argsort = mask_argsort_bwd_splits[j];
}}
float beta = j == 0 ? 0 : 1;
conv_tuner.run_with_tuned_result(
dgrad_tune_res,
kBackwardInputInt,
din,
filters,
out_bp,
mask,
mask_argsort,
tv::Tensor(), // mask_output
pair_bwd,
is_subm, // reverse_mask
mask_ptr[j],
-1, // mask_width
1.0, beta,
stream_int,
tv::Tensor(), // workspace
false, // verbose
timer);
conv_tuner.run_with_tuned_result(
wgrad_tune_res,
kBackwardWeightInt,
features, dfilters, out_bp,
mask_output_fwd[j].type_view(tv::uint32),
mask_argsort_fwd_splits[j],
tv::Tensor(), // mask_output
pair_fwd,
false, // reverse_mask
mask_ptr[j], // mask_filter
mask_width,
1.0, 0.0,
stream_int,
workspace, // workspace
false, // verbose
timer);
}}
"""
)
return
code
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment