Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
7af751dc
Commit
7af751dc
authored
Jul 12, 2022
by
yan.yan
Browse files
sync
parent
647927ce
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1739 additions
and
608 deletions
+1739
-608
CHANGELOG.md
CHANGELOG.md
+5
-0
example/simple_hash.py
example/simple_hash.py
+1
-1
spconv/algo.py
spconv/algo.py
+197
-38
spconv/algocore.py
spconv/algocore.py
+133
-0
spconv/build.py
spconv/build.py
+9
-10
spconv/constants.py
spconv/constants.py
+20
-0
spconv/core.py
spconv/core.py
+86
-54
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+71
-6
spconv/core_cc/csrc/sparse/alloc.pyi
spconv/core_cc/csrc/sparse/alloc.pyi
+54
-0
spconv/core_cc/cumm/__init__.pyi
spconv/core_cc/cumm/__init__.pyi
+14
-0
spconv/core_cc/cumm/conv/main.pyi
spconv/core_cc/cumm/conv/main.pyi
+1
-102
spconv/core_cc/cumm/gemm/main.pyi
spconv/core_cc/cumm/gemm/main.pyi
+3
-140
spconv/csrc/hash/core.py
spconv/csrc/hash/core.py
+35
-32
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+474
-28
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+195
-0
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+207
-0
spconv/csrc/sparse/cpu_core.py
spconv/csrc/sparse/cpu_core.py
+4
-9
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+220
-178
spconv/csrc/sparse/maxpool.py
spconv/csrc/sparse/maxpool.py
+8
-9
spconv/csrc/sparse/pointops.py
spconv/csrc/sparse/pointops.py
+2
-1
No files found.
CHANGELOG.md
View file @
7af751dc
# Changelog
# Changelog
## [2.1.22] - 2022-4-14
### Added
-
add full nvrtc support
-
add support for large spatial shape and batch size. if detect large shape, we use int64 instead of int32 when hashing.
## [2.1.21] - 2021-12-9
## [2.1.21] - 2021-12-9
### Added
### Added
-
add sm_37
-
add sm_37
...
...
example/simple_hash.py
View file @
7af751dc
...
@@ -56,7 +56,7 @@ def main():
...
@@ -56,7 +56,7 @@ def main():
is_empty
=
table
.
insert_exist_keys
(
keys
,
values
)
is_empty
=
table
.
insert_exist_keys
(
keys
,
values
)
ks
,
vs
,
cnt
=
table
.
items
()
ks
,
vs
,
cnt
=
table
.
items
()
cnt_item
=
cnt
.
item
()
cnt_item
=
cnt
.
item
()
print
(
cnt
,
ks
[:
cnt_item
],
vs
[:
cnt_item
])
print
(
cnt
,
ks
[:
cnt_item
],
vs
[:
cnt_item
]
,
is_empty
.
dtype
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
spconv/algo.py
View file @
7af751dc
...
@@ -12,28 +12,50 @@
...
@@ -12,28 +12,50 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
enum
import
Enum
import
contextlib
from
cumm
import
tensorview
as
tv
from
typing
import
Dict
,
List
,
Set
,
Tuple
,
Union
from
spconv.core_cc.cumm.gemm.main
import
GemmAlgoDesp
,
GemmMainUnitTest
,
GemmParams
from
spconv.core_cc.cumm.conv.main
import
ConvAlgoDesp
,
ConvMainUnitTest
,
ConvParams
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.gemm.algospec.core
import
GemmAlgo
,
ShuffleStrideType
,
get_min_arch_of_algo_str
,
get_available_algo_str_from_arch
from
cumm.gemm.codeops
import
group_by
,
div_up
from
spconv.constants
import
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
from
typing
import
Optional
import
time
import
time
from
enum
import
Enum
from
threading
import
Lock
from
threading
import
Lock
import
contextlib
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
spconv.core
import
ConvAlgo
,
AlgoHint
from
cumm
import
tensorview
as
tv
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.conv.kernel
import
ConvKernel
from
cumm.gemm.kernel
import
GemmKernel
from
cumm.gemm.algospec.core
import
(
GemmAlgo
,
ShuffleStrideType
,
get_available_algo_str_from_arch
,
get_min_arch_of_algo_str
)
from
cumm.gemm.codeops
import
div_up
,
group_by
from
cumm.nvrtc
import
CummNVRTCModule
,
get_cudadevrt_path
from
cumm.tensorview.gemm
import
ConvAlgoDesp
from
cumm.tensorview.gemm
import
ConvOpType
as
ConvOpTypeCpp
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm
import
dtypes
from
spconv.constants
import
(
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
,
SPCONV_NVRTC_MODE
,
SPCONV_DEBUG_NVRTC_KERNELS
)
from
spconv.core
import
ALL_IMPGEMM_PARAMS
,
AlgoHint
,
ConvAlgo
from
spconv.core_cc.cumm.conv.main
import
ConvMainUnitTest
from
spconv.core_cc.cumm.gemm.main
import
GemmMainUnitTest
from
spconv.cppconstants
import
COMPILED_CUDA_ARCHS
from
cumm.tensorview.gemm
import
NVRTCParams
from
spconv.tools
import
CUDAKernelTimer
from
spconv.tools
import
CUDAKernelTimer
from
cumm.gemm.constants
import
NVRTCConstants
,
NVRTCMode
from
spconv
import
algocore
from
cumm.conv.main
import
gen_gemm_kernels
as
gen_conv_kernels
from
cumm.gemm.main
import
gen_gemm_kernels
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
ALL_CONV_ALGO_DESPS
=
ConvMainUnitTest
.
get_all_conv_algo_desp
()
ALL_CONV_ALGO_DESPS
=
ConvMainUnitTest
.
get_all_conv_algo_desp
()
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
class
SimpleGemmAlgoMeta
:
class
SimpleGemmAlgoMeta
:
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
tile_ks
:
List
[
int
],
tile_ks
:
List
[
int
],
...
@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
...
@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
class
BestAlgoByProfile
:
class
BestAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
splitk
=
splitk
self
.
arch
=
arch
class
BestConvAlgoByProfile
:
class
BestConvAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
splitk
=
splitk
self
.
arch
=
arch
def
_get_nvrtc_params
(
mod
:
CummNVRTCModule
,
ker
:
Union
[
GemmKernel
,
ConvKernel
],
kernel_name
:
str
):
nvrtc_mode
=
SPCONV_NVRTC_MODE
nvrtc_params
=
tv
.
gemm
.
NVRTCParams
()
nvrtc_params
.
cumodule
=
mod
.
get_cpp_object
()
nvrtc_params
.
mode
=
nvrtc_mode
.
value
nvrtc_params
.
num_threads
=
ker
.
num_threads
nvrtc_params
.
smem_size
=
ker
.
smem_size
ns
=
ker
.
namespace
if
nvrtc_mode
==
NVRTCMode
.
DynamicParallism
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::nvrtc_kernel"
)
elif
nvrtc_mode
==
NVRTCMode
.
KernelAndCPU
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::
{
kernel_name
}
"
)
nvrtc_params
.
init_kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::nvrtc_kernel_cpu_out"
)
nvrtc_params
.
param_size
=
mod
.
const_values
[
f
"
{
ns
}
::
{
NVRTCConstants
.
SIZEOF_KEY
}
"
]
nvrtc_params
.
param_storage
=
tv
.
empty
([
nvrtc_params
.
param_size
],
tv
.
uint8
,
0
)
nvrtc_params
.
param_storage_cpu
=
tv
.
empty
(
[
nvrtc_params
.
param_size
],
tv
.
uint8
,
-
1
,
pinned
=
True
)
elif
nvrtc_mode
==
NVRTCMode
.
Direct
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::
{
kernel_name
}
"
)
elif
nvrtc_mode
==
NVRTCMode
.
ConstantMemory
:
nvrtc_params
.
kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::
{
kernel_name
}
"
)
nvrtc_params
.
init_kernel_name
=
mod
.
get_lowered_name
(
f
"
{
ns
}
::nvrtc_kernel_cpu_out"
)
nvrtc_params
.
param_size
=
mod
.
const_values
[
f
"
{
ns
}
::
{
NVRTCConstants
.
SIZEOF_KEY
}
"
]
nvrtc_params
.
constant_name
=
mod
.
get_lowered_name
(
f
"&
{
ns
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
)
nvrtc_params
.
param_storage
=
tv
.
empty
([
nvrtc_params
.
param_size
],
tv
.
uint8
,
0
)
else
:
raise
NotImplementedError
return
nvrtc_params
class
SimpleGemm
:
class
SimpleGemm
:
def
__init__
(
self
,
desps
:
List
[
GemmAlgoDesp
])
->
None
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
GemmAlgoDesp
])
->
None
:
self
.
desps
=
desps
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
if
SPCONV_DEBUG_NVRTC_KERNELS
:
self
.
prebuilt_desp_names
.
clear
()
self
.
lock
=
Lock
()
self
.
lock
=
Lock
()
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
desps
)
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_
desps
)
self
.
static_key_to_meta
:
Dict
[
_GEMM_STATIC_KEY
,
self
.
static_key_to_meta
:
Dict
[
_GEMM_STATIC_KEY
,
SimpleGemmAlgoMeta
]
=
{}
SimpleGemmAlgoMeta
]
=
{}
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
...
@@ -94,15 +162,44 @@ class SimpleGemm:
...
@@ -94,15 +162,44 @@ class SimpleGemm:
self
.
mn_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
],
self
.
mn_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
],
BestAlgoByProfile
]
=
{}
# for backward weight
BestAlgoByProfile
]
=
{}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
@
staticmethod
@
staticmethod
def
get_static_key
(
d
:
GemmAlgoDesp
)
->
_GEMM_STATIC_KEY
:
def
get_static_key
(
d
:
GemmAlgoDesp
)
->
_GEMM_STATIC_KEY
:
return
(
d
.
trans_a
,
d
.
trans_b
,
d
.
trans_c
,
d
.
dtype_a
,
d
.
dtype_b
,
return
(
d
.
trans_a
,
d
.
trans_b
,
d
.
trans_c
,
d
.
dtype_a
,
d
.
dtype_b
,
d
.
dtype_c
,
d
.
shuffle_type
,
d
.
algo
)
d
.
dtype_c
,
d
.
shuffle_type
.
value
,
d
.
algo
)
def
device_synchronize
(
self
):
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
return
GemmMainUnitTest
.
device_synchronize
()
def
_compile_nvrtc_module
(
self
,
desp
:
GemmAlgoDesp
):
params
=
algocore
.
get_gemm_param_from_desp
(
desp
)
kernel
=
gen_gemm_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
assert
cudadevrt_p
is
not
None
,
"DynamicParallism must have cudadevrt"
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
def
_cached_get_nvrtc_params
(
self
,
desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
key
=
(
str
(
desp
),
arch
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
mod
,
ker
=
self
.
_compile_nvrtc_module
(
desp
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"gemm_kernel"
)
self
.
_nvrtc_caches
[
key
]
=
nvrtc_params
return
nvrtc_params
def
get_all_available
(
def
get_all_available
(
self
,
self
,
a
:
tv
.
Tensor
,
a
:
tv
.
Tensor
,
...
@@ -135,6 +232,11 @@ class SimpleGemm:
...
@@ -135,6 +232,11 @@ class SimpleGemm:
ldb
=
b
.
dim
(
1
)
ldb
=
b
.
dim
(
1
)
ldc
=
c
.
dim
(
1
)
ldc
=
c
.
dim
(
1
)
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
arch
not
in
COMPILED_CUDA_ARCHS
:
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
return
finally_algos
return
finally_algos
...
@@ -334,6 +436,8 @@ class SimpleGemm:
...
@@ -334,6 +436,8 @@ class SimpleGemm:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
params
=
GemmParams
()
params
=
GemmParams
()
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
a
=
a
params
.
a
=
a
params
.
b
=
b
params
.
b
=
b
params
.
c
=
c_
params
.
c
=
c_
...
@@ -361,7 +465,7 @@ class SimpleGemm:
...
@@ -361,7 +465,7 @@ class SimpleGemm:
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
splitk
=
spk
))
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
min_time
=
1000
min_time
=
1000
min_idx
=
-
1
min_idx
=
-
1
...
@@ -421,6 +525,9 @@ class SimpleGemm:
...
@@ -421,6 +525,9 @@ class SimpleGemm:
if
profile_res
.
splitk
>
1
:
if
profile_res
.
splitk
>
1
:
split_k_slices
=
profile_res
.
splitk
split_k_slices
=
profile_res
.
splitk
params
=
GemmParams
()
params
=
GemmParams
()
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
a
=
a
params
.
a
=
a
params
.
b
=
b
params
.
b
=
b
params
.
c
=
c
params
.
c
=
c
...
@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
...
@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class
SimpleConv
:
class
SimpleConv
:
def
__init__
(
self
,
desps
:
List
[
ConvAlgoDesp
])
->
None
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
ConvAlgoDesp
])
->
None
:
self
.
desps
=
desps
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
.
clear
()
self
.
lock
=
Lock
()
self
.
lock
=
Lock
()
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
desps
)
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_
desps
)
self
.
static_key_to_meta
:
Dict
[
_CONV_STATIC_KEY
,
self
.
static_key_to_meta
:
Dict
[
_CONV_STATIC_KEY
,
SimpleGemmAlgoMeta
]
=
{}
SimpleGemmAlgoMeta
]
=
{}
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
...
@@ -500,28 +610,36 @@ class SimpleConv:
...
@@ -500,28 +610,36 @@ class SimpleConv:
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
@
staticmethod
@
staticmethod
def
get_static_key
(
d
:
ConvAlgoDesp
)
->
_CONV_STATIC_KEY
:
def
get_static_key
(
d
:
ConvAlgoDesp
)
->
_CONV_STATIC_KEY
:
return
(
d
.
layout_i
,
d
.
layout_w
,
d
.
layout_o
,
d
.
interleave_i
,
return
(
d
.
layout_i
.
value
,
d
.
layout_w
.
value
,
d
.
layout_o
.
value
,
d
.
interleave_w
,
d
.
interleave_o
,
d
.
dtype_input
,
d
.
dtype_weight
,
d
.
interleave_i
,
d
.
interleave_w
,
d
.
interleave_o
,
d
.
dtype_input
,
d
.
dtype_output
,
d
.
algo
,
d
.
op_type
)
d
.
dtype_weight
,
d
.
dtype_output
,
d
.
algo
,
d
.
op_type
.
value
)
def
device_synchronize
(
self
):
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
return
GemmMainUnitTest
.
device_synchronize
()
def
get_all_available
(
self
,
inp
:
tv
.
Tensor
,
weight
:
tv
.
Tensor
,
def
get_all_available
(
self
,
out
:
tv
.
Tensor
,
layout_i
:
ConvLayout
,
inp
:
tv
.
Tensor
,
layout_w
:
ConvLayout
,
layout_o
:
ConvLayout
,
weight
:
tv
.
Tensor
,
arch
:
Tuple
[
int
,
int
],
op_type
:
ConvOpType
,
out
:
tv
.
Tensor
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
):
layout_i
:
ConvLayout
,
layout_w
:
ConvLayout
,
layout_o
:
ConvLayout
,
arch
:
Tuple
[
int
,
int
],
op_type
:
ConvOpType
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
):
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
use_f32_as_accum
=
False
use_f32_as_accum
=
False
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
# for 3d conv, if reduce axis is too large, may cause nan during
# for 3d conv, if reduce axis is too large, may cause nan during
# forward.
# forward.
if
is_fp16
:
if
is_fp16
:
if
fp32_accum
is
None
:
if
fp32_accum
is
None
:
...
@@ -551,7 +669,7 @@ class SimpleConv:
...
@@ -551,7 +669,7 @@ class SimpleConv:
if
use_f32_as_accum
:
if
use_f32_as_accum
:
if
desp
.
dacc
==
tv
.
float16
:
if
desp
.
dacc
==
tv
.
float16
:
continue
continue
ldi
=
inp
.
dim
(
-
1
)
ldi
=
inp
.
dim
(
-
1
)
ldw
=
weight
.
dim
(
-
1
)
ldw
=
weight
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
...
@@ -560,6 +678,11 @@ class SimpleConv:
...
@@ -560,6 +678,11 @@ class SimpleConv:
assert
mask_width
>
0
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
arch
not
in
COMPILED_CUDA_ARCHS
:
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
return
finally_algos
return
finally_algos
...
@@ -592,6 +715,34 @@ class SimpleConv:
...
@@ -592,6 +715,34 @@ class SimpleConv:
return
desp
.
query_conv_workspace_size
(
mnk
[
0
],
mnk
[
1
],
mnk
[
2
],
splitk
,
return
desp
.
query_conv_workspace_size
(
mnk
[
0
],
mnk
[
1
],
mnk
[
2
],
splitk
,
kv
)
kv
)
def
_compile_nvrtc_module
(
self
,
desp
:
ConvAlgoDesp
):
params
=
algocore
.
get_conv_param_from_desp
(
desp
)
kernel
=
gen_conv_kernels
(
params
,
SPCONV_NVRTC_MODE
)
kernel
.
namespace
=
"spconv"
custom_names
=
[]
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
ConstantMemory
:
custom_names
=
[
f
"&
{
kernel
.
namespace
}
::
{
NVRTCConstants
.
CONSTANT_PARAM_KEY
}
"
]
cudadevrt
=
""
if
SPCONV_NVRTC_MODE
==
NVRTCMode
.
DynamicParallism
:
cudadevrt_p
=
get_cudadevrt_path
()
assert
cudadevrt_p
is
not
None
,
"DynamicParallism must have cudadevrt"
cudadevrt
=
str
(
cudadevrt_p
)
mod
=
CummNVRTCModule
([
kernel
],
cudadevrt_path
=
cudadevrt
,
verbose
=
False
,
custom_names
=
custom_names
)
mod
.
load
()
return
mod
,
kernel
def
_cached_get_nvrtc_params
(
self
,
desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
]):
key
=
(
str
(
desp
),
arch
)
if
key
in
self
.
_nvrtc_caches
:
return
self
.
_nvrtc_caches
[
key
]
mod
,
ker
=
self
.
_compile_nvrtc_module
(
desp
)
nvrtc_params
=
_get_nvrtc_params
(
mod
,
ker
,
"conv_kernel"
)
self
.
_nvrtc_caches
[
key
]
=
nvrtc_params
return
nvrtc_params
def
tune_and_cache
(
self
,
def
tune_and_cache
(
self
,
op_type
:
ConvOpType
,
op_type
:
ConvOpType
,
inp
:
tv
.
Tensor
,
inp
:
tv
.
Tensor
,
...
@@ -613,7 +764,7 @@ class SimpleConv:
...
@@ -613,7 +764,7 @@ class SimpleConv:
stream
:
int
=
0
,
stream
:
int
=
0
,
fp32_accum
:
Optional
[
bool
]
=
None
):
fp32_accum
:
Optional
[
bool
]
=
None
):
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
layout_o
,
arch
,
op_type
,
mask_width
,
layout_o
,
arch
,
op_type
,
mask_width
,
fp32_accum
)
fp32_accum
)
inp
=
inp
.
clone
()
inp
=
inp
.
clone
()
weight
=
weight
.
clone
()
weight
=
weight
.
clone
()
...
@@ -626,7 +777,10 @@ class SimpleConv:
...
@@ -626,7 +777,10 @@ class SimpleConv:
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
for
desp
in
avail
:
for
desp
in
avail
:
# for sparse conv, ndim isn't used, so we just provide a constant value.
# for sparse conv, ndim isn't used, so we just provide a constant value.
params
=
ConvParams
(
NDIM_DONT_CARE
,
op_type
.
value
)
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type
.
value
))
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
conv_algo_desp
=
desp
params
.
conv_algo_desp
=
desp
params
.
input
=
inp
params
.
input
=
inp
params
.
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
params
.
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
...
@@ -657,13 +811,16 @@ class SimpleConv:
...
@@ -657,13 +811,16 @@ class SimpleConv:
GemmMainUnitTest
.
stream_synchronize
(
stream
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
t
=
time
.
time
()
t
=
time
.
time
()
params
.
split_k_slices
=
spk
params
.
split_k_slices
=
spk
ConvMainUnitTest
.
implicit_gemm2
(
params
)
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
tv
.
gemm
.
run_nvrtc_conv_kernel
(
params
)
else
:
ConvMainUnitTest
.
implicit_gemm2
(
params
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
this_times
.
append
(
time
.
time
()
-
t
)
this_times
.
append
(
time
.
time
()
-
t
)
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
splitk
=
spk
))
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
if
not
all_profile_res
:
if
not
all_profile_res
:
raise
ValueError
(
"can't find suitable algorithm for"
,
op_type
)
raise
ValueError
(
"can't find suitable algorithm for"
,
op_type
)
min_time
=
1000
min_time
=
1000
...
@@ -720,7 +877,9 @@ class SimpleConv:
...
@@ -720,7 +877,9 @@ class SimpleConv:
op_type_value
=
op_type
op_type_value
=
op_type
else
:
else
:
op_type_value
=
op_type
.
value
op_type_value
=
op_type
.
value
params
=
ConvParams
(
NDIM_DONT_CARE
,
op_type_value
)
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type_value
))
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
input
=
inp
params
.
input
=
inp
params
.
verbose
=
verbose
params
.
verbose
=
verbose
...
...
spconv/algocore.py
0 → 100644
View file @
7af751dc
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.gemm.algospec.core
import
(
GemmAlgo
,
ShuffleStrideType
)
from
cumm.tensorview.gemm
import
ConvAlgoDesp
from
cumm.tensorview.gemm
import
ConvIterAlgo
as
ConvIterAlgoCpp
from
cumm.tensorview.gemm
import
ConvOpType
as
ConvOpTypeCpp
from
cumm.tensorview.gemm
import
ConvLayoutType
as
ConvLayoutTypeCpp
from
cumm.tensorview.gemm
import
ShuffleStrideType
as
ShuffleStrideTypeCpp
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm.gemm.main
import
GemmAlgoParams
from
cumm.conv.main
import
ConvAlgoParams
,
ConvIterAlgo
from
cumm
import
dtypes
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
cumm.gemm.core
import
MetaArray
from
cumm.gemm.algospec
import
TensorOp
def
_assign_gemm_desp_props
(
desp
:
Union
[
ConvAlgoDesp
,
GemmAlgoDesp
],
p
:
Union
[
GemmAlgoParams
,
ConvAlgoParams
]):
desp
.
dtype_a
=
p
.
dtype_a
.
tv_dtype
desp
.
dtype_b
=
p
.
dtype_a
.
tv_dtype
desp
.
dtype_c
=
p
.
dtype_a
.
tv_dtype
desp
.
dacc
=
p
.
dtype_acc
.
tv_dtype
desp
.
dcomp
=
p
.
dtype_comp
.
tv_dtype
desp
.
trans_a
=
p
.
trans_a
desp
.
trans_b
=
p
.
trans_b
desp
.
trans_c
=
p
.
trans_c
desp
.
tile_shape
=
(
p
.
ts
[
0
],
p
.
ts
[
1
],
p
.
ts
[
2
])
desp
.
warp_tile_shape
=
(
p
.
wts
[
0
],
p
.
wts
[
1
],
p
.
wts
[
2
])
if
p
.
tensorop
is
not
None
:
desp
.
tensorop
=
(
p
.
tensorop
[
0
],
p
.
tensorop
[
1
],
p
.
tensorop
[
2
])
desp
.
num_stage
=
p
.
num_stage
desp
.
algo
=
p
.
algo
.
value
desp
.
split_k_serial
=
p
.
splitk_serial
desp
.
split_k_parallel
=
p
.
splitk_parallel
desp
.
shuffle_type
=
ShuffleStrideTypeCpp
(
p
.
shuffle_stride
.
value
)
desp
.
access_per_vector
=
p
.
access_per_vector
desp
.
is_nvrtc
=
p
.
is_nvrtc
def
get_gemm_algo_desp_from_param
(
p
:
GemmAlgoParams
):
desp
=
GemmAlgoDesp
()
_assign_gemm_desp_props
(
desp
,
p
)
return
desp
def
get_conv_algo_desp_from_param
(
p
:
ConvAlgoParams
):
desp
=
ConvAlgoDesp
(
p
.
ndim
,
ConvOpTypeCpp
(
p
.
op_type
.
value
))
_assign_gemm_desp_props
(
desp
,
p
)
# conv attrs
desp
.
ndim
=
p
.
ndim
desp
.
op_type
=
ConvOpTypeCpp
(
p
.
op_type
.
value
)
desp
.
iter_algo
=
ConvIterAlgoCpp
(
p
.
iter_algo
.
value
)
desp
.
layout_i
=
ConvLayoutTypeCpp
(
p
.
layout_desp_input
.
layout_type
.
value
)
desp
.
layout_w
=
ConvLayoutTypeCpp
(
p
.
layout_desp_weight
.
layout_type
.
value
)
desp
.
layout_o
=
ConvLayoutTypeCpp
(
p
.
layout_desp_output
.
layout_type
.
value
)
desp
.
interleave_i
=
p
.
layout_desp_input
.
interleave
desp
.
interleave_w
=
p
.
layout_desp_weight
.
interleave
desp
.
interleave_o
=
p
.
layout_desp_output
.
interleave
desp
.
mask_sparse
=
p
.
mask_sparse
desp
.
increment_k_first
=
p
.
increment_k_first
return
desp
def
_assign_gemm_params
(
desp
:
Union
[
ConvAlgoDesp
,
GemmAlgoDesp
],
p
:
Union
[
GemmAlgoParams
,
ConvAlgoParams
]):
p
.
dtype_a
=
dtypes
.
get_dtype_from_tvdtype
(
desp
.
dtype_a
)
p
.
dtype_b
=
dtypes
.
get_dtype_from_tvdtype
(
desp
.
dtype_b
)
p
.
dtype_c
=
dtypes
.
get_dtype_from_tvdtype
(
desp
.
dtype_c
)
p
.
dtype_acc
=
dtypes
.
get_dtype_from_tvdtype
(
desp
.
dacc
)
p
.
dtype_comp
=
dtypes
.
get_dtype_from_tvdtype
(
desp
.
dcomp
)
p
.
trans_a
=
desp
.
trans_a
p
.
trans_b
=
desp
.
trans_b
p
.
trans_c
=
desp
.
trans_c
p
.
ts
=
MetaArray
(
*
desp
.
tile_shape
)
p
.
wts
=
MetaArray
(
*
desp
.
warp_tile_shape
)
if
desp
.
tensorop
[
0
]
>
0
:
p
.
tensorop
=
TensorOp
(
(
desp
.
tensorop
[
0
],
desp
.
tensorop
[
1
],
desp
.
tensorop
[
2
]))
p
.
num_stage
=
desp
.
num_stage
p
.
algo
=
GemmAlgo
(
desp
.
algo
)
p
.
splitk_serial
=
desp
.
split_k_serial
p
.
splitk_parallel
=
desp
.
split_k_parallel
p
.
shuffle_stride
=
ShuffleStrideType
(
desp
.
shuffle_type
.
value
)
p
.
access_per_vector
=
desp
.
access_per_vector
p
.
is_nvrtc
=
desp
.
is_nvrtc
def
get_gemm_param_from_desp
(
desp
:
GemmAlgoDesp
):
p
=
GemmAlgoParams
((
0
,
0
,
0
),
(
0
,
0
,
0
),
0
,
"s8,s8,s8,s8,s8"
,
False
,
False
,
False
,
GemmAlgo
.
Simt
)
_assign_gemm_params
(
desp
,
p
)
return
p
def
get_conv_param_from_desp
(
desp
:
ConvAlgoDesp
):
p
=
ConvAlgoParams
(
desp
.
ndim
,
ConvOpType
.
kForward
,
ConvIterAlgo
.
Optimized
,
(
0
,
0
,
0
),
(
0
,
0
,
0
),
0
,
"s8,s8,s8,s8,s8"
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
)
_assign_gemm_params
(
desp
,
p
)
# conv attrs
p
.
ndim
=
desp
.
ndim
p
.
op_type
=
ConvOpType
(
desp
.
op_type
.
value
)
p
.
iter_algo
=
ConvIterAlgo
(
desp
.
iter_algo
.
value
)
p
.
layout_desp_input
=
ConvLayout
(
ConvLayoutType
(
desp
.
layout_i
.
value
),
desp
.
interleave_i
)
p
.
layout_desp_weight
=
ConvLayout
(
ConvLayoutType
(
desp
.
layout_w
.
value
),
desp
.
interleave_w
)
p
.
layout_desp_output
=
ConvLayout
(
ConvLayoutType
(
desp
.
layout_o
.
value
),
desp
.
interleave_o
)
p
.
mask_sparse
=
desp
.
mask_sparse
p
.
increment_k_first
=
desp
.
increment_k_first
return
p
spconv/build.py
View file @
7af751dc
...
@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
...
@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from
cumm.common
import
CompileInfo
from
cumm.common
import
CompileInfo
from
spconv.csrc.sparse.all
import
SpconvOps
from
spconv.csrc.sparse.all
import
SpconvOps
from
spconv.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.hash.core
import
HashTable
from
spconv.csrc.hash.core
import
HashTable
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
cu
=
GemmMainUnitTest
(
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
all_shuffle
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_shuffle
))
SHUFFLE_TURING_PARAMS
)
cu
=
GemmMainUnitTest
(
all_shuffle
)
cu
.
namespace
=
"cumm.gemm.main"
cu
.
namespace
=
"cumm.gemm.main"
convcu
=
ConvMainUnitTest
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
)
IMPLGEMM_TURING_PARAMS
)
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
.
namespace
=
"cumm.conv.main"
convcu
.
namespace
=
"cumm.conv.main"
objects_folder
=
None
pccm
.
builder
.
build_pybind
([
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
(),
ExternalAllocator
()],
if
InWindows
:
# windows have command line limit, so we use objects_folder to reduce command size.
objects_folder
=
"objects"
pccm
.
builder
.
build_pybind
([
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
()],
PACKAGE_ROOT
/
"core_cc"
,
PACKAGE_ROOT
/
"core_cc"
,
namespace_root
=
PACKAGE_ROOT
,
namespace_root
=
PACKAGE_ROOT
,
objects_folder
=
objects_folder
,
load_library
=
False
)
load_library
=
False
)
spconv/constants.py
View file @
7af751dc
...
@@ -16,6 +16,8 @@ import os
...
@@ -16,6 +16,8 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
from
typing
import
List
from
pccm.utils
import
project_is_editable
,
project_is_installed
from
pccm.utils
import
project_is_editable
,
project_is_installed
from
cumm.gemm.constants
import
NVRTCMode
import
enum
PACKAGE_NAME
=
"spconv"
PACKAGE_NAME
=
"spconv"
PACKAGE_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
PACKAGE_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
...
@@ -43,3 +45,21 @@ else:
...
@@ -43,3 +45,21 @@ else:
# for f16 backward weight, larger splitk, larger compute error.
# for f16 backward weight, larger splitk, larger compute error.
# so we use this env to control maximum splitk.
# so we use this env to control maximum splitk.
SPCONV_BWD_SPLITK
=
list
(
map
(
int
,
os
.
getenv
(
"SPCONV_BWD_SPLITK"
,
"1,2,4,8,16,32,64"
).
split
(
","
)))
SPCONV_BWD_SPLITK
=
list
(
map
(
int
,
os
.
getenv
(
"SPCONV_BWD_SPLITK"
,
"1,2,4,8,16,32,64"
).
split
(
","
)))
SPCONV_NVRTC_MODE
=
NVRTCMode
.
ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS
=
False
class
SpconvAllocatorKeys
:
Pair
=
"Pair"
IndiceNumPerLoc
=
"IndiceNumPerLoc"
PairMask
=
"PairMask"
MaskArgSort
=
"MaskArgSort"
OutIndices
=
"OutIndices"
PairFwd
=
"PairFwd"
# PairMaskFwd = "PairMaskFwd"
PairMaskBwd
=
"PairMaskBwd"
# MaskArgSortFwd = "MaskArgSortFwd"
MaskArgSortBwd
=
"MaskArgSortBwd"
OutFeatures
=
"OutFeatures"
spconv/core.py
View file @
7af751dc
...
@@ -15,17 +15,17 @@ from enum import Enum
...
@@ -15,17 +15,17 @@ from enum import Enum
from
cumm.gemm.main
import
gen_shuffle_params_v2
as
gen_shuffle_params
,
GemmAlgoParams
from
cumm.gemm.main
import
gen_shuffle_params_v2
as
gen_shuffle_params
,
GemmAlgoParams
from
cumm.gemm
import
kernel
from
cumm.gemm
import
kernel
from
typing
import
List
from
typing
import
List
from
cumm.gemm.algospec.core
import
TensorOp
Params
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvEnum
,
ConvIterAlgo
,
ConvLayout
,
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.constants
import
NDIM_DONT_CARE
from
spconv.constants
import
NDIM_DONT_CARE
class
ConvAlgo
(
Enum
):
class
ConvAlgo
(
Enum
):
Native
=
"Native"
Native
=
0
MaskImplicitGemm
=
"MaskImplicitGemm"
MaskImplicitGemm
=
1
MaskSplitImplicitGemm
=
"MaskSplitImplicitGemm"
MaskSplitImplicitGemm
=
2
class
AlgoHint
(
Enum
):
class
AlgoHint
(
Enum
):
...
@@ -40,17 +40,17 @@ class AlgoHint(Enum):
...
@@ -40,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# *gen_shuffle_params(
# *gen_shuffle_params(
...
@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
...
@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
32
),
(
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Volta, TensorOp
Params
((8, 8, 4))),
# kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
64
,
32
),
(
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
32
),
(
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
]
]
# SHUFFLE_VOLTA_PARAMS = []
# SHUFFLE_VOLTA_PARAMS = []
SHUFFLE_TURING_PARAMS
:
List
[
GemmAlgoParams
]
=
[
SHUFFLE_TURING_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
32
),
(
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Turing, TensorOp
Params
((16, 8, 8))),
# kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
64
),
(
64
,
64
,
64
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
64
),
(
64
,
128
,
64
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
64
,
32
),
(
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
32
),
(
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# kernel.GemmAlgo.Turing, TensorOp
Params
((8, 8, 16))),
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
]
]
# SHUFFLE_TURING_PARAMS = []
# SHUFFLE_TURING_PARAMS = []
...
@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
...
@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
]
]
IMPLGEMM_SIMT_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvBwdWeight
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_VOLTA_PARAMS
=
[
IMPLGEMM_VOLTA_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
...
@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp
Params
((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# gen_conv_params(ConvFwdAndBwdInput, )
]
]
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
ALL_IMPGEMM_PARAMS
=
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
7af751dc
...
@@ -34,13 +34,15 @@ class SpconvOps:
...
@@ -34,13 +34,15 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_stage2(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor,
indice_pairs_uniq_before_sort: Tensor,
out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs:
indice_pairs:
indice_pairs_uniq:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
out_inds:
num_out_act:
num_out_act:
batch_size:
batch_size:
...
@@ -74,14 +76,16 @@ class SpconvOps:
...
@@ -74,14 +76,16 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor,
indice_pairs_uniq_before_sort: Tensor,
out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs_fwd:
indice_pairs_fwd:
indice_pairs_bwd:
indice_pairs_bwd:
indice_pairs_uniq:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
out_inds:
mask_fwd:
mask_fwd:
mask_bwd:
mask_bwd:
...
@@ -98,11 +102,12 @@ class SpconvOps:
...
@@ -98,11 +102,12 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs:
indice_pairs:
out_inds:
out_inds:
indice_num_per_loc:
indice_num_per_loc:
...
@@ -276,6 +281,18 @@ class SpconvOps:
...
@@ -276,6 +281,18 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_split_allocator_v2(data: Tensor, allocator, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
"""
Args:
data:
allocator:
mask:
indices:
stream:
mask_output:
"""
...
@staticmethod
def count_bits(a: Tensor) -> Tensor:
def count_bits(a: Tensor) -> Tensor:
"""
"""
Args:
Args:
...
@@ -328,3 +345,51 @@ class SpconvOps:
...
@@ -328,3 +345,51 @@ class SpconvOps:
stream_int:
stream_int:
"""
"""
...
...
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
is_train:
stream_int:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) -> None:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
"""
...
spconv/core_cc/csrc/sparse/alloc.pyi
0 → 100644
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def free(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
def free_noexcept(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
spconv/core_cc/cumm/__init__.pyi
View file @
7af751dc
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spconv/core_cc/cumm/conv/main.pyi
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from ...cumm.gemm.main import GemmAlgoDesp
from cumm.tensorview.gemm import ConvParams
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvAlgoDesp(GemmAlgoDesp):
ndim: int
op_type: int
iter_algo: int
layout_i: int
layout_w: int
layout_o: int
interleave_i: int
interleave_w: int
interleave_o: int
mask_sparse: bool
increment_k_first: bool
def __init__(self, ndim: int, op_type: int) -> None:
"""
Args:
ndim:
op_type:
"""
...
def __repr__(self) -> str: ...
@staticmethod
def conv_iwo_012_to_abc(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@staticmethod
def gemm_abc_012_to_iwo(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@property
def dtype_input(self) -> int: ...
@property
def dtype_weight(self) -> int: ...
@property
def dtype_output(self) -> int: ...
def supported(self, m: int, n: int, k: int, C: int, K: int, mask_width: int) -> bool:
"""
Args:
m:
n:
k:
C:
K:
mask_width:
"""
...
def query_conv_workspace_size(self, m: int, n: int, k: int, split_k_slices: int, kv: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
kv:
"""
...
def supported_ldx_conv(self, ldi: int, ldw: int, ldo: int) -> bool:
"""
Args:
ldi:
ldw:
ldo:
"""
...
class ConvParams:
conv_algo_desp: Any
input: Tensor
weight: Tensor
output: Tensor
split_k_slices: int
padding: List[int]
stride: List[int]
dilation: List[int]
alpha: float
beta: float
mask_width: int
mask_filter: int
reverse_mask: bool
verbose: bool
timer: CUDAKernelTimer
workspace: Tensor = Tensor()
mask: Tensor = Tensor()
mask_argsort: Tensor = Tensor()
indices: Tensor = Tensor()
mask_output: Tensor = Tensor()
stream: int
def __init__(self, ndim: int, op_type: int, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
ndim:
op_type:
timer:
"""
...
class ConvMainUnitTest:
class ConvMainUnitTest:
@staticmethod
@staticmethod
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
...
...
spconv/core_cc/cumm/gemm/main.pyi
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import GemmParams
from cumm.tensorview import CUDAKernelTimer
class GemmAlgoDesp:
dtype_a: int
dtype_b: int
dtype_c: int
tile_shape: Tuple[int, int, int]
warp_tile_shape: Tuple[int, int, int]
num_stage: int
dacc: int
dcomp: int
algo: str
tensorop: List[int]
split_k_serial_: int
split_k_parallel_: int
shuffle_type: str
element_per_access_a: int
element_per_access_b: int
element_per_access_c: int
access_per_vector: int
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
@property
def split_k_serial(self) -> bool: ...
@split_k_serial.setter
def split_k_serial(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def split_k_parallel(self) -> bool: ...
@split_k_parallel.setter
def split_k_parallel(self, val: bool) -> None:
"""
Args:
val:
"""
...
def check_valid(self) -> None: ...
@property
def trans_a(self) -> bool: ...
@trans_a.setter
def trans_a(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_b(self) -> bool: ...
@trans_b.setter
def trans_b(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_c(self) -> bool: ...
@trans_c.setter
def trans_c(self, val: bool) -> None:
"""
Args:
val:
"""
...
def query_workspace_size(self, m: int, n: int, k: int, split_k_slices: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
"""
...
def supported(self, m: int, n: int, k: int) -> bool:
"""
Args:
m:
n:
k:
"""
...
def supported_ldx(self, lda: int, ldb: int, ldc: int) -> bool:
"""
Args:
lda:
ldb:
ldc:
"""
...
class GemmParams:
algo_desp: GemmAlgoDesp
split_k_slices: int
workspace: Tensor = Tensor()
a_inds: Tensor = Tensor()
b_inds: Tensor = Tensor()
c_inds: Tensor = Tensor()
alpha: float
beta: float
stream: int
timer: CUDAKernelTimer
def __init__(self, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
timer:
"""
...
def check_valid(self) -> None: ...
@property
def a(self) -> Tensor: ...
@a.setter
def a(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def b(self) -> Tensor: ...
@b.setter
def b(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
@property
def c(self) -> Tensor: ...
@c.setter
def c(self, val: Tensor) -> None:
"""
Args:
val:
"""
...
class GemmMainUnitTest:
class GemmMainUnitTest:
@staticmethod
@staticmethod
def get_all_algo_desp() -> List[
GemmAlgoDesp
]: ...
def get_all_algo_desp() -> List[
Any
]: ...
@staticmethod
@staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "
NS
", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "
0
", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
"""
"""
Args:
Args:
a_shape:
a_shape:
...
...
spconv/csrc/hash/core.py
View file @
7af751dc
...
@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
self
.
add_member
(
"map_8_8"
,
"tsl::robin_map<uint64_t, uint64_t>"
)
self
.
add_member
(
"map_8_8"
,
"tsl::robin_map<uint64_t, uint64_t>"
)
self
.
add_pybind_member
(
"insert_count_"
,
"int64_t"
,
prop_name
=
"insert_count"
,
readwrite
=
False
)
self
.
add_pybind_member
(
"insert_count_"
,
"int64_t"
,
prop_name
=
"insert_count"
,
readwrite
=
False
)
self
.
valid_hash_key_types
=
[
dtypes
.
int32
,
dtypes
.
int64
,
dtypes
.
uint32
,
dtypes
.
uint64
]
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
@
pccm
.
constructor
def
ctor
(
self
):
def
ctor
(
self
):
...
@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
"""
)
"""
)
for
v_items
in
_dispatch_ints
(
code
,
[
4
,
8
],
"values_data.itemsize()"
):
for
v_items
in
_dispatch_ints
(
code
,
[
4
,
8
],
"values_data.itemsize()"
):
...
@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream);
tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::clear_
table
_split<table_t>, table);
launcher(tv::hash::clear_
map_kernel
_split<table_t>, table);
"""
)
"""
)
return
code
return
code
...
@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
int64_t value_after_insert = keys.dim(0) + insert_count_;
int64_t value_after_insert = keys.dim(0) + insert_count_;
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
insert_count_ += keys.dim(0);
insert_count_ += keys.dim(0);
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
}}
auto N = keys.dim(0);
auto N = keys.dim(0);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
if (!values.empty()){{
if (!values.empty()){{
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
...
@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...
@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
...
@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
"""
)
"""
)
with
code
.
if_
(
"is_cpu"
):
with
code
.
if_
(
"is_cpu"
):
map_name
=
"cpu_map"
map_name
=
"cpu_map"
...
@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...
@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...
@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<K>();
auto count_ptr = count.data_ptr<Kunsigned>();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
"""
)
"""
)
...
@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream);
tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::assign_arange_split<table_t, K>, table, count_ptr);
launcher(tv::hash::assign_arange_split<table_t, K
unsigned
>, table, count_ptr);
"""
)
"""
)
else
:
else
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
...
@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
"""
)
"""
)
with
code
.
if_
(
"is_cpu"
):
with
code
.
if_
(
"is_cpu"
):
map_name
=
"cpu_map"
map_name
=
"cpu_map"
...
@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
auto count_ptr = count.data_ptr<K>();
using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<Kunsigned>();
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...
@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::iterate_table_split<table_t, K>, table, key_ptr, value_ptr, size_t(N), count_ptr);
launcher(tv::hash::iterate_table_split<table_t, K
unsigned
>, table, key_ptr, value_ptr, size_t(N), count_ptr);
"""
)
"""
)
else
:
else
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
...
@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...
@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...
...
spconv/csrc/sparse/all.py
View file @
7af751dc
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
,
GemmBasicHost
from
cumm.conv.bases
import
ConvOpType
,
NHWC
from
cumm.conv.bases
import
ConvOpType
,
NHWC
from
cumm.conv.params
import
ConvProblem
from
cumm.conv.params
import
ConvProblem
from
cumm
import
dtypes
from
cumm
import
dtypes
...
@@ -23,7 +23,8 @@ from .pointops import Point2Voxel, Point2VoxelCPU
...
@@ -23,7 +23,8 @@ from .pointops import Point2Voxel, Point2VoxelCPU
from
.indices
import
SparseConvIndicesKernel
,
CudaCommonKernel
,
SparseConvIndicesCPU
from
.indices
import
SparseConvIndicesKernel
,
CudaCommonKernel
,
SparseConvIndicesCPU
from
.maxpool
import
IndiceMaxPool
,
IndiceMaxPoolCPU
from
.maxpool
import
IndiceMaxPool
,
IndiceMaxPoolCPU
from
.gather
import
GatherCPU
from
.gather
import
GatherCPU
from
.alloc
import
ExternalAllocator
,
ThrustAllocator
from
spconv.constants
import
SpconvAllocatorKeys
class
CustomThrustLib
(
pccm
.
Class
):
class
CustomThrustLib
(
pccm
.
Class
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -31,7 +32,7 @@ class CustomThrustLib(pccm.Class):
...
@@ -31,7 +32,7 @@ class CustomThrustLib(pccm.Class):
self
.
add_dependency
(
ThrustLib
)
self
.
add_dependency
(
ThrustLib
)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if
compat
.
InLinux
:
if
compat
.
InLinux
:
self
.
build_meta
.
add_cflags
(
"nvcc"
,
"-Xcompiler"
,
"-fno-gnu-unique"
)
self
.
build_meta
.
add_
public_
cflags
(
"nvcc"
,
"-Xcompiler"
,
"-fno-gnu-unique"
)
class
ThrustCustomAllocatorV2
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
class
ThrustCustomAllocatorV2
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
...
@@ -65,13 +66,13 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -65,13 +66,13 @@ class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"ptr"
,
"char *"
)
code
.
arg
(
"ptr"
,
"char *"
)
code
.
arg
(
"num_bytes"
,
"size_t"
)
code
.
arg
(
"num_bytes"
,
"size_t"
)
return
code
return
code
class
SpconvOps
(
pccm
.
Class
):
class
SpconvOps
(
pccm
.
Class
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
ThrustCustomAllocatorV2
)
self
.
add_dependency
(
ThrustCustomAllocatorV2
,
ExternalAllocator
,
GemmBasicHost
,
ThrustAllocator
)
self
.
ndims
=
[
1
,
2
,
3
,
4
]
self
.
ndims
=
[
1
,
2
,
3
,
4
]
for
ndim
in
self
.
ndims
:
for
ndim
in
self
.
ndims
:
p2v
=
Point2Voxel
(
dtypes
.
float32
,
ndim
)
p2v
=
Point2Voxel
(
dtypes
.
float32
,
ndim
)
...
@@ -167,8 +168,8 @@ class SpconvOps(pccm.Class):
...
@@ -167,8 +168,8 @@ class SpconvOps(pccm.Class):
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2
(
self
):
def
generate_conv_inds_stage2
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq,
indice_pairs_uniq_before_sort,
out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"std::vector<int>"
)
...
@@ -198,8 +199,9 @@ class SpconvOps(pccm.Class):
...
@@ -198,8 +199,9 @@ class SpconvOps(pccm.Class):
padding_[i] = padding[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
dilation_[i] = dilation[i];
}}
}}
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2(indices, hashdata,
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2(indices,
indice_pairs, indice_pairs_uniq, out_inds, num_out_act,
hashdata_k, hashdata_v, indice_pairs,
indice_pairs_uniq, indice_pairs_uniq_before_sort, out_inds, num_out_act,
batch_size, output_dims_, input_dims_,
batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
}}
...
@@ -260,9 +262,9 @@ class SpconvOps(pccm.Class):
...
@@ -260,9 +262,9 @@ class SpconvOps(pccm.Class):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
return
code
.
make_invalid
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
code
.
arg
(
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds"
,
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq,
indice_pairs_uniq_before_sort,
out_inds"
,
"tv::Tensor"
)
"tv::Tensor"
)
code
.
arg
(
"mask_fwd, mask_bwd"
,
"tv::Tensor"
)
code
.
arg
(
"mask_fwd, mask_bwd"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"num_out_act"
,
"int"
)
...
@@ -291,8 +293,11 @@ class SpconvOps(pccm.Class):
...
@@ -291,8 +293,11 @@ class SpconvOps(pccm.Class):
padding_[i] = padding[i];
padding_[i] = padding[i];
dilation_[i] = dilation[i];
dilation_[i] = dilation[i];
}}
}}
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2_mask(indices, hashdata,
return SpconvIndices
{
ndim
}
D::generate_conv_inds_stage2_mask(
indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds, mask_fwd, mask_bwd,
indices, hashdata_k, hashdata_v,
indice_pairs_fwd, indice_pairs_bwd,
indice_pairs_uniq, indice_pairs_uniq_before_sort,
out_inds, mask_fwd, mask_bwd,
num_out_act, batch_size, output_dims_, input_dims_,
num_out_act, batch_size, output_dims_, input_dims_,
ksize_, stride_, padding_, dilation_, transposed, stream_int);
ksize_, stride_, padding_, dilation_, transposed, stream_int);
}}
}}
...
@@ -307,7 +312,7 @@ class SpconvOps(pccm.Class):
...
@@ -307,7 +312,7 @@ class SpconvOps(pccm.Class):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
return
code
.
make_invalid
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, out_inds, indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, out_inds, indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
...
@@ -331,7 +336,8 @@ class SpconvOps(pccm.Class):
...
@@ -331,7 +336,8 @@ class SpconvOps(pccm.Class):
ksize_[i] = ksize[i];
ksize_[i] = ksize[i];
dilation_[i] = dilation[i];
dilation_[i] = dilation[i];
}}
}}
return SpconvIndices
{
ndim
}
D::generate_subm_conv_inds(indices, hashdata,
return SpconvIndices
{
ndim
}
D::generate_subm_conv_inds(indices,
hashdata_k, hashdata_v,
indice_pairs, out_inds, indice_num_per_loc,
indice_pairs, out_inds, indice_num_per_loc,
batch_size, input_dims_,
batch_size, input_dims_,
ksize_, dilation_, indice_pair_mask, backward,
ksize_, dilation_, indice_pair_mask, backward,
...
@@ -566,7 +572,7 @@ class SpconvOps(pccm.Class):
...
@@ -566,7 +572,7 @@ class SpconvOps(pccm.Class):
}}
}}
"""
"""
code
.
add_dependency
(
ThrustLib
,
TensorViewKernel
)
code
.
add_dependency
(
Custom
ThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
code
.
raw
(
f
"""
code
.
raw
(
f
"""
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
...
@@ -588,14 +594,15 @@ class SpconvOps(pccm.Class):
...
@@ -588,14 +594,15 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
def
sort_1d_by_key_allocator_template
(
self
,
use_allocator
:
bool
):
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_allocator
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
return
code
.
make_invalid
()
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"alloc_func"
,
"std::function<std::uintptr_t(std::size_t)>"
)
if
not
use_allocator
:
code
.
arg
(
"alloc_func"
,
"std::function<std::uintptr_t(std::size_t)>"
)
else
:
code
.
arg
(
"allocator"
,
"ThrustAllocator&"
)
code
.
arg
(
"indices"
,
code
.
arg
(
"indices"
,
"tv::Tensor"
,
"tv::Tensor"
,
...
@@ -614,10 +621,13 @@ class SpconvOps(pccm.Class):
...
@@ -614,10 +621,13 @@ class SpconvOps(pccm.Class):
}}
}}
}}
}}
"""
"""
code
.
add_dependency
(
ThrustLib
,
TensorViewKernel
)
code
.
add_dependency
(
Custom
ThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
if
not
use_allocator
:
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
"""
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
if (indices.empty()){{
if (indices.empty()){{
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
indices = tv::empty({{data.dim(0)}}, tv::int32, 0);
...
@@ -638,6 +648,19 @@ class SpconvOps(pccm.Class):
...
@@ -638,6 +648,19 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_allocator
(
self
):
# for python
return
self
.
sort_1d_by_key_allocator_template
(
False
)
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_allocator_v2
(
self
):
# for cpp only
return
self
.
sort_1d_by_key_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_split
(
self
):
def
sort_1d_by_key_split
(
self
):
...
@@ -694,14 +717,15 @@ class SpconvOps(pccm.Class):
...
@@ -694,14 +717,15 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
def
sort_1d_by_key_split_allocator_template
(
self
,
use_allocator
:
bool
):
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_split_allocator
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
if
CUMM_CPU_ONLY_BUILD
:
if
CUMM_CPU_ONLY_BUILD
:
return
code
.
make_invalid
()
return
code
.
make_invalid
()
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"alloc_func"
,
"std::function<std::uintptr_t(std::size_t)>"
)
if
not
use_allocator
:
code
.
arg
(
"alloc_func"
,
"std::function<std::uintptr_t(std::size_t)>"
)
else
:
code
.
arg
(
"allocator"
,
"ThrustAllocator&"
)
code
.
arg
(
"mask"
,
"tv::Tensor"
)
code
.
arg
(
"mask"
,
"tv::Tensor"
)
...
@@ -727,9 +751,11 @@ class SpconvOps(pccm.Class):
...
@@ -727,9 +751,11 @@ class SpconvOps(pccm.Class):
"""
"""
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_dependency
(
CustomThrustLib
,
TensorViewKernel
)
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
code
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
if
not
use_allocator
:
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
"""
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
ThrustCustomAllocatorV2 allocator{{alloc_func}};
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
cudaStream_t stream_cu = reinterpret_cast<cudaStream_t>(stream);
// auto timer = tv::CudaContextTimer<>();
// auto timer = tv::CudaContextTimer<>();
if (indices.empty()){{
if (indices.empty()){{
...
@@ -755,6 +781,18 @@ class SpconvOps(pccm.Class):
...
@@ -755,6 +781,18 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
return
code
.
ret
(
"tv::Tensor"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_split_allocator
(
self
):
return
self
.
sort_1d_by_key_split_allocator_template
(
False
)
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
def
sort_1d_by_key_split_allocator_v2
(
self
):
return
self
.
sort_1d_by_key_split_allocator_template
(
True
)
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
count_bits
(
self
):
def
count_bits
(
self
):
...
@@ -947,3 +985,411 @@ class SpconvOps(pccm.Class):
...
@@ -947,3 +985,411 @@ class SpconvOps(pccm.Class):
"""
)
"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
code
.
raw
(
f
"""TV_THROW_RT_ERR("unknown ndim", ndim);"""
)
return
code
.
ret
(
"std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>"
)
return
code
.
ret
(
"std::tuple<tv::Tensor, tv::Tensor, tv::Tensor>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_int32_max
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
raw
(
f
"return std::numeric_limits<int>::max();"
)
return
code
.
ret
(
"int"
)
@
pccm
.
static_function
def
get_conv_output_size
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation"
,
f
"std::vector<int>"
)
code
.
raw
(
f
"""
int ndim = input_dims.size();
std::vector<int> out_dims;
for (int i = 0; i < ndim; ++i){{
if (ksize[i] == -1){{
out_dims.push_back(1);
}}else{{
auto size = (input_dims[i] + 2 * padding[i] - dilation[i] *
(ksize[i] - 1) - 1) / stride[i] + 1;
out_dims.push_back(size);
}}
}}
return out_dims;
"""
)
return
code
.
ret
(
"std::vector<int>"
)
@
pccm
.
static_function
def
get_deconv_output_size
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"ksize, stride, padding, dilation, output_padding"
,
f
"std::vector<int>"
)
code
.
raw
(
f
"""
int ndim = input_dims.size();
std::vector<int> out_dims;
for (int i = 0; i < ndim; ++i){{
if (ksize[i] == -1){{
TV_THROW_INVALID_ARG("kernel size can't be -1");
}}else{{
auto size = (input_dims[i] - 1) * stride[i] - 2 * padding[i] + ksize[
i] + output_padding[i];
out_dims.push_back(size);
}}
}}
return out_dims;
"""
)
return
code
.
ret
(
"std::vector<int>"
)
@
pccm
.
cuda
.
static_function
def
apply_thrust_unique_to_indice_pairs_uniq
(
self
):
code
=
pccm
.
code
()
code
.
add_dependency
(
CustomThrustLib
)
code
.
arg
(
"data"
,
"tv::Tensor"
)
code
.
arg
(
"allocator"
,
"ThrustAllocator&"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
int num_out_act = 0;
int uniq_size = data.dim(0);
tv::dispatch<int32_t, int64_t>(data.dtype(), [&](auto I){{
using T = TV_DECLTYPE(I);
thrust::device_ptr<T> ptr_tr(data.data_ptr<T>());
auto thrust_ctx = thrust::cuda::par(allocator).on(reinterpret_cast<cudaStream_t>(stream_int));
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
num_out_act = new_end - ptr_tr - 1;
}});
return num_out_act;
"""
)
return
code
.
ret
(
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_indice_pairs_implicit_gemm
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"algo"
,
"int"
)
code
.
arg
(
"ksize, stride, padding, dilation, out_padding"
,
f
"std::vector<int>"
)
code
.
arg
(
"subm, transposed, is_train"
,
f
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
"""
)
return
code
.
ret
(
"tv::Tensor"
)
code
.
raw
(
f
"""
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape;
if (!subm){{
if (transposed){{
out_shape = get_deconv_output_size(input_dims, ksize, stride, padding, dilation, out_padding);
}}else{{
out_shape = get_conv_output_size(input_dims, ksize, stride, padding, dilation);
}}
}}else{{
out_shape = input_dims;
}}
for (auto& v : out_shape){{
if (v <= 0){{
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kMaskImplicitGemm ||
conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm, "only support implicit gemm");
bool is_mask_split = conv_algo == tv::gemm::SparseConvAlgo::kMaskSplitImplicitGemm;
int mask_split_count = is_mask_split ? 2 : 1;
tv::Tensor pair;
if (subm){{
pair = allocator.full_int(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
Pair
)
}
,
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}else{{
pair = allocator.full_int(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
Pair
)
}
,
{{kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
}}
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
IndiceNumPerLoc
)
}
,
{{kv}}, indices.dtype(), indices.device());
tv::Tensor mask_tensor = tv::zeros({{mask_split_count}}, tv::uint32, -1);
auto mask_tensor_ptr = mask_tensor.data_ptr<uint32_t>();
if (is_mask_split){{
auto kv_div_2 = kv / 2;
auto remain = kv - kv_div_2;
uint64_t mask_np_1 = 1;
uint64_t first = ((mask_np_1 << remain) - 1);
uint64_t second = ((mask_np_1 << kv_div_2) - 1) << remain;
mask_tensor_ptr[0] = uint32_t(first);
mask_tensor_ptr[1] = uint32_t(second);
}}
else{{
mask_tensor_ptr[1] = 0xffffffff;
}}
tv::Tensor out_inds;
ThrustAllocator thrustalloc(allocator);
if (subm){{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
auto pair_mask = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
PairMask
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, pair_mask, false, stream_int);
auto mask_argsort = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
for (int j = 0; j < mask_split_count; ++j){{
sort_1d_by_key_allocator_v2(pair_mask[j], thrustalloc, mask_argsort[j], stream_int);
}}
}}else{{
auto pair_bwd = pair;
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
generate_conv_inds_mask_stage1(indices, pair_bwd, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
OutIndices
)
}
,
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
auto pair_fwd = allocator.full_int(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
PairFwd
)
}
,
{{kv, num_act_out}}, -1, indices.dtype(), indices.device());
auto pair_mask_fwd = allocator.zeros(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
PairMask
)
}
,
{{mask_split_count, num_act_out}}, tv::uint32, 0);
auto pair_mask_bwd = tv::Tensor();
if (is_train){{
pair_mask_bwd = allocator.zeros(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
PairMaskBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_mask_stage2(indices, hash_k, hash_v, pair_fwd, pair_bwd,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, pair_mask_fwd, pair_mask_bwd, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
auto mask_argsort_fwd = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
MaskArgSort
)
}
,
{{mask_split_count, out_inds.dim(0)}}, tv::uint32, 0);
tv::Tensor mask_argsort_bwd = tv::Tensor();
if (is_train){{
mask_argsort_bwd = allocator.zeros(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
MaskArgSortBwd
)
}
,
{{mask_split_count, indices.dim(0)}}, tv::uint32, 0);
}}
if (is_mask_split){{
for (int j = 0; j < mask_split_count; ++j){{
if (!is_train){{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
}}else{{
sort_1d_by_key_split_allocator_v2(pair_mask_fwd[j], thrustalloc,
mask_tensor[j], mask_argsort_fwd[j], stream_int);
sort_1d_by_key_split_allocator_v2(pair_mask_bwd[j], thrustalloc,
mask_tensor[j], mask_argsort_bwd[j], stream_int);
}}
}}
}}else{{
if (!is_train){{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
}}else{{
sort_1d_by_key_allocator_v2(pair_mask_fwd[0], thrustalloc,
mask_argsort_fwd[0], stream_int);
sort_1d_by_key_allocator_v2(pair_mask_bwd[0], thrustalloc,
mask_argsort_bwd[0], stream_int);
}}
}}
}}
return mask_tensor;
"""
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
get_indice_pairs
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"indices"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"input_dims"
,
f
"std::vector<int>"
)
code
.
arg
(
"algo"
,
"int"
)
code
.
arg
(
"ksize, stride, padding, dilation, out_padding"
,
f
"std::vector<int>"
)
code
.
arg
(
"subm, transposed"
,
f
"bool"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
"""
)
return
code
code
.
raw
(
f
"""
int kv = std::accumulate(ksize.begin(), ksize.end(), 1, std::multiplies<int>());
TV_ASSERT_RT_ERR(kv <= 32, "currently only support ksize < 32");
auto conv_algo = static_cast<tv::gemm::SparseConvAlgo>(algo);
TV_ASSERT_RT_ERR(conv_algo == tv::gemm::SparseConvAlgo::kNative, "only support kNative");
std::vector<int64_t> input_dims_i64(input_dims.begin(), input_dims.end());
int64_t spatial_volume = std::accumulate(input_dims_i64.begin(),
input_dims_i64.end(), int64_t(1), std::multiplies<int64_t>());
bool use_int64_hash_k = spatial_volume >= int64_t(std::numeric_limits<int>::max());
tv::DType indice_uniq_dtype = use_int64_hash_k ? tv::int64 : tv::int32;
std::vector<int> out_shape;
if (!subm){{
if (transposed){{
out_shape = get_deconv_output_size(input_dims, ksize, stride, padding, dilation, out_padding);
}}else{{
out_shape = get_conv_output_size(input_dims, ksize, stride, padding, dilation);
}}
}}else{{
out_shape = input_dims;
}}
for (auto& v : out_shape){{
if (v <= 0){{
TV_THROW_RT_ERR("your out spatial shape", out_shape, "ratch zero!, input shape:", input_dims);
}}
}}
tv::Tensor pair;
pair = allocator.full_int(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
Pair
)
}
,
{{2, kv, indices.dim(0)}}, -1, indices.dtype(), indices.device());
auto indice_num_per_loc = allocator.zeros(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
IndiceNumPerLoc
)
}
,
{{kv}}, indices.dtype(), indices.device());
tv::Tensor out_inds;
"""
)
with
code
.
if_
(
"subm"
):
code
.
raw
(
f
"""
if (indices.is_cpu()){{
generate_subm_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation);
}}
"""
)
if
not
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
else {{
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
out_inds = indices;
int num_points = out_inds.dim(0);
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_points * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_points * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_points * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_subm_conv_inds(indices, hash_k, hash_v, pair, out_inds, indice_num_per_loc,
batch_size, input_dims, ksize, dilation, tv::Tensor(), false, stream_int);
}}
"""
)
else
:
code
.
raw
(
f
"""
else {{
TV_THROW_RT_ERR("not implemented for CPU ONLY build.")
}}
"""
)
with
code
.
else_
():
code
.
raw
(
f
"""
if (indices.is_cpu()){{
out_inds = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
OutIndices
)
}
,
{{kv * indices.dim(0), indices.dim(1)}}, indices.dtype(), -1);
generate_conv_inds_cpu(indices, pair, out_inds, indice_num_per_loc,
batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed);
}}
"""
)
if
not
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
else {{
ThrustAllocator thrustalloc(allocator);
auto tvctx = tv::Context();
tvctx.set_cuda_stream(reinterpret_cast<cudaStream_t>(stream_int));
auto indice_pairs_uniq_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
auto indice_pairs_uniq = indice_pairs_uniq_guard->tensor;
auto indice_pairs_uniq_bkp_guard = allocator.empty_guard({{int64_t(pair.numel() + 1)}}, indice_uniq_dtype, 0);
generate_conv_inds_stage1(indices, pair, indice_pairs_uniq,
indice_num_per_loc, batch_size, out_shape, input_dims, ksize,
stride, padding, dilation, transposed, stream_int);
indice_pairs_uniq_bkp_guard->tensor.copy_(indice_pairs_uniq, tvctx);
// TODO pytorch unique may be faster?
int num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int) - 1;
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_act_out);
out_inds = allocator.empty(
{
pccm
.
literal
(
SpconvAllocatorKeys
.
OutIndices
)
}
,
{{num_act_out, indices.dim(1)}}, indices.dtype(), 0);
ExternalAllocator::guard_t hash_k_guard, hash_v_gurad, hash_kv_gurad;
tv::Tensor hash_k, hash_v;
if (use_int64_hash_k){{
hash_k_guard = allocator.empty_guard({{num_act_out * 2}}, tv::int64, 0);
hash_v_gurad = allocator.empty_guard({{num_act_out * 2}}, tv::int32, 0);
hash_k = hash_k_guard->tensor;
hash_v = hash_v_gurad->tensor;
}}else{{
hash_kv_gurad = allocator.empty_guard({{2, num_act_out * 2}}, tv::int32, 0);
hash_k = hash_kv_gurad->tensor[0];
hash_v = hash_kv_gurad->tensor[1];
}}
generate_conv_inds_stage2(indices, hash_k, hash_v, pair,
indice_pairs_uniq, indice_pairs_uniq_bkp_guard->tensor,
out_inds, num_act_out,
batch_size, out_shape, input_dims, ksize, stride, padding, dilation,
transposed, stream_int);
}}
"""
)
else
:
code
.
raw
(
f
"""
else {{
TV_THROW_RT_ERR("not implemented for CPU ONLY build.")
}}
"""
)
code
.
raw
(
f
"""
return;
"""
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
test_allocator
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
raw
(
f
"""
auto guard = allocator.zeros_guard({{1, 2, 3}}, tv::int32, 0);
tv::ssprint("????");
"""
)
return
code
\ No newline at end of file
spconv/csrc/sparse/alloc.py
0 → 100644
View file @
7af751dc
import
pccm
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
class
ExternalAllocatorGuard
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
)
self
.
add_member
(
"tensor"
,
"tv::Tensor"
)
self
.
add_member
(
"free_func"
,
"std::function<void(tv::Tensor)>"
)
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"ten"
,
"tv::Tensor"
)
code
.
arg
(
"free_func"
,
"std::function<void(tv::Tensor)>"
)
code
.
ctor_init
(
"tensor"
,
"ten"
)
code
.
ctor_init
(
"free_func"
,
"free_func"
)
return
code
@
pccm
.
constructor
def
dctor
(
self
):
code
=
pccm
.
code
()
return
code
@
pccm
.
destructor
def
dtor
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"""
if (!tensor.empty() && free_func){{
free_func(tensor);
}}
"""
)
return
code
class
ExternalAllocator
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
ExternalAllocatorGuard
)
self
.
use_shared
=
True
self
.
ptr_type
=
"unique"
if
self
.
use_shared
:
self
.
ptr_type
=
"shared"
self
.
add_typedef
(
"guard_t"
,
f
"std::
{
self
.
ptr_type
}
_ptr<ExternalAllocatorGuard>"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
zeros
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
empty
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
full_int
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
full_float
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"name"
,
"std::string"
)
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"float"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
return
code
.
ret
(
"tv::Tensor"
)
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
free
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"ten"
,
"tv::Tensor"
)
return
code
@
pccm
.
pybind
.
mark
(
virtual
=
True
)
@
pccm
.
member_function
(
virtual
=
True
,
pure_virtual
=
True
)
def
free_noexcept
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"ten"
,
"tv::Tensor"
)
return
code
@
pccm
.
member_function
def
zeros_guard
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
raw
(
f
"""
// "" means temp memory
auto ten = zeros("", shape, dtype, device);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
"""
)
return
code
.
ret
(
f
"std::
{
self
.
ptr_type
}
_ptr<ExternalAllocatorGuard>"
)
@
pccm
.
member_function
def
empty_guard
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
raw
(
f
"""
auto ten = empty("", shape, dtype, device);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
"""
)
return
code
.
ret
(
f
"std::
{
self
.
ptr_type
}
_ptr<ExternalAllocatorGuard>"
)
@
pccm
.
member_function
def
full_int_guard
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
raw
(
f
"""
auto ten = full_int("", shape, value, dtype, device);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor ten){{
this->free(ten);
}});
"""
)
return
code
.
ret
(
f
"std::
{
self
.
ptr_type
}
_ptr<ExternalAllocatorGuard>"
)
@
pccm
.
member_function
def
full_float_guard
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"shape"
,
"std::vector<int64_t>"
)
code
.
arg
(
"value"
,
"int"
)
code
.
arg
(
"dtype"
,
"int"
)
code
.
arg
(
"device"
,
"int"
)
code
.
raw
(
f
"""
auto ten = full_float("", shape, value, dtype, device);
return std::make_
{
self
.
ptr_type
}
<ExternalAllocatorGuard>(ten, [this](tv::Tensor t){{
this->free(t);
}});
"""
)
return
code
.
ret
(
f
"std::
{
self
.
ptr_type
}
_ptr<ExternalAllocatorGuard>"
)
class
ThrustAllocator
(
pccm
.
Class
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
ExternalAllocator
)
self
.
add_include
(
"functional"
,
"memory"
)
self
.
add_member
(
"allocator_"
,
"ExternalAllocator&"
,)
self
.
add_typedef
(
"value_type"
,
"char"
)
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
ctor_init
(
"allocator_"
,
"allocator"
)
return
code
@
pccm
.
member_function
def
allocate
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"num_bytes"
,
"std::ptrdiff_t"
)
code
.
ret
(
"char*"
)
code
.
raw
(
f
"""
auto ten = allocator_.empty("", {{num_bytes}}, tv::uint8, 0);
return reinterpret_cast<char*>(ten.raw_data());
"""
)
return
code
@
pccm
.
member_function
def
deallocate
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"ptr"
,
"char *"
)
code
.
arg
(
"num_bytes"
,
"size_t"
)
code
.
raw
(
f
"""
return allocator_.free_noexcept(tv::from_blob(ptr, {{num_bytes}}, tv::uint8, 0));
"""
)
return
code
spconv/csrc/sparse/convops.py
0 → 100644
View file @
7af751dc
import
pccm
from
cumm.gemm.main
import
GemmMainUnitTest
from
cumm.conv.main
import
ConvMainUnitTest
from
.alloc
import
ExternalAllocator
from
spconv.core
import
ConvAlgo
from
spconv.constants
import
SpconvAllocatorKeys
from
cumm.constants
import
CUMM_CPU_ONLY_BUILD
from
cumm.common
import
GemmBasicHost
,
TensorView
,
NlohmannJson
class
GemmTuneResult
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
GemmBasicHost
,
TensorView
)
self
.
add_pybind_member
(
"algo_desp"
,
"tv::gemm::GemmAlgoDesp"
)
self
.
add_pybind_member
(
"arch"
,
"std::tuple<int, int>"
)
self
.
add_pybind_member
(
"splitk"
,
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
is_valid
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0"
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
defaultctor
(
self
):
code
=
pccm
.
code
()
code
.
ctor_init
(
"algo_desp"
,
"tv::gemm::GemmAlgoDesp()"
)
code
.
ctor_init
(
"arch"
,
"std::make_tuple(-1, -1)"
)
code
.
ctor_init
(
"splitk"
,
"-1"
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"algo_desp"
,
"tv::gemm::GemmAlgoDesp"
,
pyanno
=
"cumm.tensorview.gemm.GemmAlgoDesp"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"splitk"
,
"int"
)
code
.
ctor_init
(
"algo_desp"
,
"algo_desp"
)
code
.
ctor_init
(
"arch"
,
"arch"
)
code
.
ctor_init
(
"splitk"
,
"splitk"
)
return
code
class
ConvTuneResult
(
pccm
.
Class
,
pccm
.
pybind
.
PybindClassMixin
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_dependency
(
GemmBasicHost
,
TensorView
)
self
.
add_pybind_member
(
"algo_desp"
,
"tv::gemm::ConvAlgoDesp"
)
self
.
add_pybind_member
(
"arch"
,
"std::tuple<int, int>"
)
self
.
add_pybind_member
(
"splitk"
,
"int"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
defaultctor
(
self
):
code
=
pccm
.
code
()
code
.
ctor_init
(
"algo_desp"
,
"tv::gemm::ConvAlgoDesp()"
)
code
.
ctor_init
(
"arch"
,
"std::make_tuple(-1, -1)"
)
code
.
ctor_init
(
"splitk"
,
"-1"
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"algo_desp"
,
"tv::gemm::ConvAlgoDesp"
,
pyanno
=
"cumm.tensorview.gemm.ConvAlgoDesp"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"splitk"
,
"int"
)
code
.
ctor_init
(
"algo_desp"
,
"algo_desp"
)
code
.
ctor_init
(
"arch"
,
"arch"
)
code
.
ctor_init
(
"splitk"
,
"splitk"
)
return
code
@
pccm
.
pybind
.
mark
@
pccm
.
member_function
def
is_valid
(
self
):
code
=
pccm
.
code
()
code
.
raw
(
f
"return splitk > 0 && std::get<0>(arch) > 0"
)
return
code
class
GemmTunerSimple
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
gemm_cu
:
GemmMainUnitTest
,
conv_cu
:
ConvMainUnitTest
):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
ConvTuneResult
,
TensorView
)
self
.
add_param_class
(
"gemm"
,
gemm_cu
,
"GemmMain"
)
self
.
add_param_class
(
"conv"
,
conv_cu
,
"ConvMain"
)
self
.
add_include
(
"tensorview/utility/tuplehash.h"
)
self
.
add_member
(
"desps_"
,
"std::vector<tv::gemm::GemmAlgoDesp>"
)
self
.
add_member
(
"nvrtc_progs_"
,
"std::unordered_map<std::string, tv::NVRTCProgram>"
)
self
.
add_member
(
"nvrtc_caches_"
,
"std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>"
)
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
def
ctor
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"desps"
,
"std::vector<tv::gemm::GemmAlgoDesp>"
)
code
.
arg
(
"nvrtc_progs"
,
"std::unordered_map<std::string, std::string>"
)
code
.
ctor_init
(
"desps_"
,
"desps"
)
code
.
raw
(
f
"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
"""
)
return
code
@
pccm
.
member_function
def
get_all_available
(
self
):
code
=
pccm
.
code
()
code
.
arg
(
"a, b, c"
,
"tv::Tensor"
)
code
.
arg
(
"trans_a, trans_b, trans_c"
,
"bool"
)
code
.
arg
(
"arch"
,
"std::tuple<int, int>"
)
code
.
arg
(
"nvrtc_progs"
,
"std::unordered_map<std::string, std::string>"
)
code
.
ctor_init
(
"desps_"
,
"desps"
)
code
.
raw
(
f
"""
for (auto& v : nvrtc_progs){{
const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());
nvrtc_progs_.insert(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));
}}
"""
)
return
code
class
ConvGemmOps
(
pccm
.
ParameterizedClass
):
def
__init__
(
self
,
gemm_cu
:
GemmMainUnitTest
,
conv_cu
:
ConvMainUnitTest
):
super
().
__init__
()
self
.
add_dependency
(
ExternalAllocator
,
GemmTuneResult
,
ConvTuneResult
)
self
.
add_param_class
(
"gemm"
,
gemm_cu
,
"GemmMain"
)
self
.
add_param_class
(
"conv"
,
conv_cu
,
"ConvMain"
)
@
pccm
.
pybind
.
mark
@
pccm
.
static_function
def
indice_conv
(
self
):
"""1. this function need to take a out features
that from subm first mm.
2. this function don't support CPU.
"""
code
=
pccm
.
code
()
code
.
arg
(
"allocator"
,
"ExternalAllocator&"
)
code
.
arg
(
"out_features_after_mm"
,
"tv::Tensor"
)
code
.
arg
(
"features, filters, indice_pairs"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pair_num"
,
"tv::Tensor"
)
code
.
arg
(
"num_activate_out"
,
"int"
)
code
.
arg
(
"inverse"
,
"bool"
,
"false"
)
code
.
arg
(
"subm"
,
"bool"
,
"false"
)
code
.
arg
(
"algo"
,
"int"
,
f
"
{
ConvAlgo
.
Native
.
value
}
"
)
code
.
arg
(
"filter_hwio"
,
"bool"
,
"false"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
,
pyanno
=
"int"
)
if
CUMM_CPU_ONLY_BUILD
:
code
.
raw
(
f
"""
throw std::runtime_error("this function can only be used with CUDA.")
"""
)
return
code
.
ret
(
"tv::Tensor"
)
code
.
raw
(
f
"""
TV_ASSERT_RT_ERR(!features.is_cpu(), "this function don't support cpu.")
int out_channel;
if (filter_hwio){{
out_channel = filters.dim(-1);
}}else{{
out_channel = filters.dim(-2);
}}
filters = filters.view(-1, filters.dim(-2), filters.dim(-1));
int kv = filters.dim(0);
int kv_center = kv / 2;
tv::Tensor out_features;
if (kv == 1 && subm){{
return;
}}
auto indice_pair_num_cpu = indice_pair_num.cpu();
auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();
int maxnhot = 0;
bool all_zero = true;
for (int i = 0; i < kv; ++i){{
if (indice_pair_num_cpu_ptr[i] != 0){{
all_zero = false;
maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);
}}
}}
if (subm && all_zero){{
return;
}}
bool inited = subm;
auto a = features;
auto c = out_features;
auto pair_in = indice_pairs[int(inverse)];
auto pair_out = indice_pairs[int(!inverse)];
"""
)
return
code
spconv/csrc/sparse/cpu_core.py
View file @
7af751dc
...
@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
...
@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
self
.
add_dependency
(
TensorView
)
self
.
add_dependency
(
TensorView
)
self
.
add_include
(
"tensorview/parallel/all.h"
)
self
.
add_include
(
"tensorview/parallel/all.h"
)
if
compat
.
InWindows
:
if
compat
.
InWindows
:
self
.
build_meta
.
add_cflags
(
"cl"
,
"/openmp"
)
self
.
build_meta
.
add_
public_
cflags
(
"cl"
,
"/openmp"
)
else
:
else
:
self
.
build_meta
.
add_cflags
(
"g++"
,
"-fopenmp"
)
self
.
build_meta
.
add_public_cflags
(
"g++"
,
"-fopenmp"
)
self
.
build_meta
.
add_cflags
(
"clang++"
,
"-fopenmp"
)
self
.
build_meta
.
add_public_cflags
(
"clang++"
,
"-fopenmp"
)
if
"g++"
not
in
self
.
build_meta
.
compiler_to_ldflags
:
self
.
build_meta
.
add_ldflags
(
"g++,clang++"
,
"-fopenmp"
)
self
.
build_meta
.
compiler_to_ldflags
[
"g++"
]
=
[]
self
.
build_meta
.
compiler_to_ldflags
[
"g++"
].
extend
([
"-fopenmp"
])
if
"clang++"
not
in
self
.
build_meta
.
compiler_to_ldflags
:
self
.
build_meta
.
compiler_to_ldflags
[
"clang++"
]
=
[]
self
.
build_meta
.
compiler_to_ldflags
[
"clang++"
].
extend
([
"-fopenmp"
])
spconv/csrc/sparse/indices.py
View file @
7af751dc
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
contextlib
import
contextlib
from
cumm.conv.bases
import
ConvEnum
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm
import
dtypes
from
cumm
import
dtypes
import
pccm
import
pccm
...
@@ -255,7 +254,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -255,7 +254,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
self
.
add_param_class
(
"spinds"
,
self
.
loc_iter
,
"ConvLocIter"
)
self
.
add_param_class
(
"spinds"
,
self
.
loc_iter
,
"ConvLocIter"
)
self
.
add_param_class
(
"spinds"
,
problem
,
"ConvProblem"
)
self
.
add_param_class
(
"spinds"
,
problem
,
"ConvProblem"
)
self
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
self
.
add_param_class
(
"cudakers"
,
CudaCommonKernel
())
self
.
add_include
(
"tensorview/hash/ops.h"
)
self
.
ndim
=
problem
.
ndim
self
.
ndim
=
problem
.
ndim
self
.
dtype_indices
=
dtype_indices
self
.
dtype_indices
=
dtype_indices
self
.
dtype_indices_uniq
=
dtype_indices
self
.
dtype_indices_uniq
=
dtype_indices
...
@@ -265,13 +264,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -265,13 +264,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage1
(
self
):
def
calc_conv_indices_stage1
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TIndiceUniq"
)
code
.
arg
(
"loc_iter"
,
f
"ConvLocIter"
)
# [N, ndim + 1]
code
.
arg
(
"loc_iter"
,
f
"ConvLocIter"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs"
,
code
.
arg
(
"indice_pairs"
,
f
"
{
self
.
dtype_indices
}
*"
)
# [2, kernelProd, MaxSize]
f
"
{
self
.
dtype_indices
}
*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_for_uniq"
,
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"
{
self
.
dtype_i
ndice
s
}
*"
)
# [2, kernelProd, MaxSize]
f
"
TI
ndice
Uniq
*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
...
@@ -295,10 +295,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -295,10 +295,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}
}}
if (valid){{
if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
{
self
.
dtype_indices
}
offset = loc_iter.layout_npq(npq_offset);
int64_t
offset = loc_iter.layout_npq(npq_offset);
if (old_num < indices_pair_size){{
if (old_num < indices_pair_size){{
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = offset;
//
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] = offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size + old_num] = offset;
}}
}}
}}
}}
...
@@ -314,7 +314,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -314,7 +314,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"indices_out"
,
f
"int*"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_for_uniq"
,
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"const
{
self
.
dtype_indices
}
*"
)
# [2, kernelProd, MaxSize]
f
"const
typename TTable::key_type
*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"layout_npq"
,
code
.
arg
(
"layout_npq"
,
f
"spinds::LayoutNPQ"
)
# [2, kernelProd, MaxSize]
f
"spinds::LayoutNPQ"
)
# [2, kernelProd, MaxSize]
...
@@ -323,7 +323,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -323,7 +323,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
for (int output_index : tv::KernelLoopX<int>(num_indices)) {{
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_for_uniq[output_index];
auto
output_coord_offset = indice_pairs_for_uniq[output_index];
layout_npq.inverse(output_coord_offset, indices_out +
{
self
.
ndim
+
1
}
* output_index);
layout_npq.inverse(output_coord_offset, indices_out +
{
self
.
ndim
+
1
}
* output_index);
table.insert(output_coord_offset, output_index);
table.insert(output_coord_offset, output_index);
}}
}}
...
@@ -334,20 +334,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -334,20 +334,24 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
calc_conv_indices_stage2
(
self
):
def
calc_conv_indices_stage2
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_out_part"
,
f
"int*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"indices_pair_size"
,
"int"
)
code
.
arg
(
"indices_pair_size"
,
"int"
)
# TODO use block instead of filter_offset?
# TODO use block instead of filter_offset?
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int filter_offset = blockIdx.y;
int filter_offset = blockIdx.y;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_out_part_filter = indice_pairs_out_part + filter_offset * indices_pair_size;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * indices_pair_size;
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_
out_pa
rt_filter[i];
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_
uniq_before_so
rt_filter[i];
if (output_coord_offset
> -1
){{
if (output_coord_offset
!= std::numeric_limits<typename TTable::key_type>::max()
){{
auto
ptr
= table.lookup_
ptr
(output_coord_offset);
auto
table_offset
= table.lookup_
offset
(output_coord_offset);
if (
ptr
){{
if (
table_offset != -1
){{
indice_pairs_out_part_filter[i] =
ptr->second
;
indice_pairs_out_part_filter[i] =
table.value_ptr()[table_offset]
;
}}
}}
}}
}}
}}
}}
...
@@ -357,13 +361,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -357,13 +361,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
@
pccm
.
cuda
.
cuda_global_function
def
calc_conv_indices_stage1_mask
(
self
):
def
calc_conv_indices_stage1_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TIndiceUniq"
)
code
.
arg
(
"loc_iter"
,
f
"ConvLocIter"
)
# [N, ndim + 1]
code
.
arg
(
"loc_iter"
,
f
"ConvLocIter"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_bwd"
,
code
.
arg
(
"indice_pairs_bwd"
,
f
"
{
self
.
dtype_indices
}
*"
)
# [
2,
kernelProd, MaxSize]
f
"
{
self
.
dtype_indices
}
*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"indice_pairs_for_uniq"
,
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"
{
self
.
dtype_i
ndice
s
}
*"
)
# [2, kernelProd, MaxSize]
f
"
TI
ndice
Uniq
*"
)
# [2, kernelProd, MaxSize]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"indice_num_per_loc"
,
f
"int*"
)
# [kernelProd]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
...
@@ -386,12 +392,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -386,12 +392,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
}}
}}
if (valid){{
if (valid){{
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
{
self
.
dtype_i
ndice
s
}
output_coord_offset = loc_iter.layout_npq(npq_offset);
TI
ndice
Uniq
output_coord_offset = loc_iter.layout_npq(npq_offset);
// if (old_num < indices_pair_size){{
// if (old_num < indices_pair_size){{
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
// indice_pairs_bwd[filter_offset_mul_indices_pair_size + input_index] = output_coord_offset;
indice_pairs_
bwd
[filter_offset_mul_indices_pair_size +
input_index
] = output_coord_offset;
//
indice_pairs_
for_uniq
[filter_offset_mul_indices_pair_size +
old_num
] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size +
old_num
] = output_coord_offset;
indice_pairs_for_uniq[filter_offset_mul_indices_pair_size +
input_index
] = output_coord_offset;
// }}
// }}
}}
}}
}}
}}
...
@@ -407,6 +413,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -407,6 +413,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
code
.
arg
(
"indice_pairs_bwd"
,
code
.
arg
(
"indice_pairs_bwd"
,
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_bwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_bwd"
,
f
"uint32_t*"
)
# [kernelProd]
...
@@ -422,12 +429,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -422,12 +429,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_bwd_filter[input_index];
auto output_coord_offset = indice_pairs_uniq_before_sort_filter[input_index];
if (output_coord_offset > -1){{
if (output_coord_offset != std::numeric_limits<typename TTable::key_type>::max()){{
auto ptr = table.lookup_ptr(output_coord_offset);
if (ptr){{
auto table_offset = table.lookup_offset(output_coord_offset);
auto output_index = ptr->second;
if (table_offset != -1){{
auto output_index = table.value_ptr()[table_offset];
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index, filter_mask_fwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
// atomicOr(mask_bwd + input_index, filter_mask_bwd);
indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_fwd_filter[output_index] = input_index;
...
@@ -465,11 +475,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -465,11 +475,14 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
calc_conv_indices_stage2_inference_mask
(
self
):
def
calc_conv_indices_stage2_inference_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indice_pairs_fwd"
,
code
.
arg
(
"indice_pairs_fwd"
,
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
f
"int*"
)
# [kernelProd, MaxSize], inp -> out
code
.
arg
(
"indice_pairs_bwd"
,
code
.
arg
(
"indice_pairs_bwd"
,
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
f
"int*"
)
# [kernelProd, MaxSize], out -> inp
code
.
arg
(
"indice_pairs_uniq_before_sort"
,
f
"const typename TTable::key_type*"
)
# [kernelProd, MaxSize]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"mask_fwd"
,
f
"uint32_t*"
)
# [kernelProd]
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_in"
,
"int"
)
code
.
arg
(
"num_indices_out"
,
"int"
)
code
.
arg
(
"num_indices_out"
,
"int"
)
...
@@ -481,12 +494,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -481,12 +494,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_fwd_filter = indice_pairs_fwd + filter_offset * num_indices_out;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_bwd_filter = indice_pairs_bwd + filter_offset * num_indices_in;
auto indice_pairs_uniq_before_sort_filter = indice_pairs_uniq_before_sort + filter_offset * num_indices_in;
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
for (int input_index : tv::KernelLoopX<int>(num_indices_in)) {{
{
self
.
dtype_indices
}
output_coord_offset = indice_pairs_
bwd
_filter[input_index];
auto
output_coord_offset = indice_pairs_
uniq_before_sort
_filter[input_index];
if (output_coord_offset
> -1
){{
if (output_coord_offset
!= std::numeric_limits<typename TTable::key_type>::max()
){{
auto
ptr
= table.lookup_
ptr
(output_coord_offset);
auto
table_offset
= table.lookup_
offset
(output_coord_offset);
if (
ptr
){{
if (
table_offset != -1
){{
auto output_index =
ptr->second
;
auto output_index =
table.value_ptr()[table_offset]
;
atomicOr(mask_fwd + output_index, filter_mask_fwd);
atomicOr(mask_fwd + output_index, filter_mask_fwd);
indice_pairs_fwd_filter[output_index] = input_index;
indice_pairs_fwd_filter[output_index] = input_index;
}}
}}
...
@@ -499,7 +513,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -499,7 +513,6 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
def
build_subm_conv_hash_table
(
self
):
def
build_subm_conv_hash_table
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
targ
(
"TTable"
)
code
.
targ
(
"TTable"
)
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"table"
,
f
"TTable"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
code
.
arg
(
"indices_in"
,
f
"const int*"
)
# [N, ndim + 1]
...
@@ -509,8 +522,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -509,8 +522,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
for (int i : tv::KernelLoopX<int>(num_indices)) {{
for (int i : tv::KernelLoopX<int>(num_indices)) {{
{
self
.
dtype_indices
}
index = layout_npq(indices_in + i *
{
self
.
ndim
+
1
}
);
table.insert(layout_npq(indices_in + i *
{
self
.
ndim
+
1
}
), i);
table.insert(index, i);
}}
}}
"""
)
"""
)
return
code
return
code
...
@@ -518,11 +530,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -518,11 +530,12 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
@
pccm
.
cuda
.
cuda_global_function
def
clean_indices_uniq
(
self
):
def
clean_indices_uniq
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"
{
self
.
dtype_indices
}
*"
)
code
.
targ
(
"T"
)
code
.
arg
(
"size"
,
f
"
{
self
.
dtype_indices
}
"
)
code
.
arg
(
"indice_pairs_for_uniq"
,
f
"T*"
)
code
.
arg
(
"size"
,
f
"size_t"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
for (
{
self
.
dtype_indices
}
i : tv::KernelLoopX<
{
self
.
dtype_indices
}
>(size)) {{
for (
size_t
i : tv::KernelLoopX<
size_t
>(size)) {{
indice_pairs_for_uniq[i] = std::numeric_limits<
{
self
.
dtype_indices
}
>::max();
indice_pairs_for_uniq[i] = std::numeric_limits<
T
>::max();
}}
}}
"""
)
"""
)
return
code
return
code
...
@@ -559,13 +572,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -559,13 +572,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
for (int i : tv::KernelLoopX<int>(num_indices_in)) {{
tv::array<int,
{
self
.
ndim
+
1
}
> npq_offset;
tv::array<int,
{
self
.
ndim
+
1
}
> npq_offset;
if (loc_iter.query_npq_no_stride(indices_in + i *
{
self
.
ndim
+
1
}
, npq_offset)){{
if (loc_iter.query_npq_no_stride(indices_in + i *
{
self
.
ndim
+
1
}
, npq_offset)){{
{
self
.
dtype_indices
}
offset = loc_iter.layout_npq(npq_offset);
auto offset = loc_iter.layout_npq(npq_offset);
auto item = table.lookup(offset); // performance bound
// auto item = table.lookup(offset); // performance bound
if (!item.empty()){{
auto table_offset = table.lookup_offset(offset); // performance bound
if (table_offset != -1){{
auto v = table.value_ptr()[table_offset];
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
int old_num = tv::cuda::atomicAggInc(indice_num_per_loc + filter_offset);
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[filter_offset_mul_indices_pair_size + old_num] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] =
item.second
;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size + old_num] =
v
;
indice_pairs[filter_offset_mul_indices_pair_size_1 + old_num] =
item.second
;
indice_pairs[filter_offset_mul_indices_pair_size_1 + old_num] =
v
;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i;
indice_pairs[indices_pair_size_mul_RS + filter_offset_mul_indices_pair_size_1 + old_num] = i;
}}
}}
}}
}}
...
@@ -613,10 +628,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -613,10 +628,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::array<int,
{
self
.
ndim
+
1
}
> nhw_offset;
tv::array<int,
{
self
.
ndim
+
1
}
> nhw_offset;
// table: input indice coord to output index (or output indice coord to input index)
// table: input indice coord to output index (or output indice coord to input index)
if (loc_iter.query_nhw(indices_in + output_index *
{
self
.
ndim
+
1
}
, nhw_offset)){{
if (loc_iter.query_nhw(indices_in + output_index *
{
self
.
ndim
+
1
}
, nhw_offset)){{
{
self
.
dtype_indices
}
offset = loc_iter.layout_npq(nhw_offset);
auto offset = loc_iter.layout_npq(nhw_offset);
auto item = table.lookup(offset);
// auto item = table.lookup(offset);
if (!item.empty()) {{
auto table_offset = table.lookup_offset(offset); // performance bound
auto input_index = item.second; // we find a input indice idx.
if (table_offset != -1){{
auto input_index = table.value_ptr()[table_offset]; // we find a input indice idx.
atomicOr(mask + output_index, filter_mask_out);
atomicOr(mask + output_index, filter_mask_out);
atomicOr(mask + input_index, filter_mask_in);
atomicOr(mask + input_index, filter_mask_in);
// for this output, we set correct input idx.
// for this output, we set correct input idx.
...
@@ -670,10 +686,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -670,10 +686,10 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::array<int,
{
self
.
ndim
+
1
}
> nhw_offset;
tv::array<int,
{
self
.
ndim
+
1
}
> nhw_offset;
// table: input indice coord to output index (or output indice coord to input index)
// table: input indice coord to output index (or output indice coord to input index)
if (loc_iter.query_nhw(indices_in + output_index *
{
self
.
ndim
+
1
}
, nhw_offset)){{
if (loc_iter.query_nhw(indices_in + output_index *
{
self
.
ndim
+
1
}
, nhw_offset)){{
{
self
.
dtype_indices
}
offset = loc_iter.layout_npq(nhw_offset);
auto
offset = loc_iter.layout_npq(nhw_offset);
auto
item
= table.lookup(offset);
auto
table_offset
= table.lookup
_offset
(offset);
// performance bound
if (
!item.empty())
{{
if (
table_offset != -1)
{{
auto input_index =
item.second
; // we find a input indice idx.
auto input_index =
table.value_ptr()[table_offset]
; // we find a input indice idx.
atomicOr(mask1 + output_index, filter_mask_out);
atomicOr(mask1 + output_index, filter_mask_out);
atomicOr(mask2 + input_index, filter_mask_in);
atomicOr(mask2 + input_index, filter_mask_in);
// for this output, we set correct input idx.
// for this output, we set correct input idx.
...
@@ -706,10 +722,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -706,10 +722,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
int kv = tv::arrayops::prod(
ksize
);
int kv =
ksize.op<
tv::arrayops::prod
>
();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<
{
self
.
dtype_indices
}
>::max(),
"kernel volume must smaller than max value of
{
self
.
dtype_indices
}
");
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}});
tv::check_shape(indice_pairs, {{2, kv, indices.dim(0)}});
...
@@ -724,11 +738,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -724,11 +738,17 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
launcher_clean_uniq(clean_indices_uniq, indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1, loc_iter, indices.data_ptr<const int>(),
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
indice_pairs.data_ptr<
{
self
.
dtype_indices
}
>(),
using T = TV_DECLTYPE(I);
indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
indice_pairs.dim(2), kv, transposed);
"kernel volume must smaller than max value of T");
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs.data_ptr<
{
self
.
dtype_indices
}
>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
indice_pairs.dim(2), kv, transposed);
}});
// thrust::device_ptr<
{
self
.
dtype_indices
}
> ptr_tr(indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>());
// thrust::device_ptr<
{
self
.
dtype_indices
}
> ptr_tr(indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>());
// auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
// auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
// thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
// thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
...
@@ -745,11 +765,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -745,11 +765,15 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
arg
(
"uniq_size"
,
"int64_t"
)
code
.
arg
(
"uniq_size"
,
"int64_t"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
arg
(
"stream_int"
,
f
"std::uintptr_t"
,
"0"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
thrust::device_ptr<
{
self
.
dtype_indices
}
> ptr_tr(indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>());
int num_out_act = 0;
auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
using T = TV_DECLTYPE(I);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
thrust::device_ptr<T> ptr_tr(indice_pairs_uniq.data_ptr<T>());
auto num_out_act = new_end - ptr_tr - 1;
auto thrust_ctx = thrust::cuda::par.on(reinterpret_cast<cudaStream_t>(stream_int));
thrust::sort(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
auto new_end = thrust::unique(thrust_ctx, ptr_tr, ptr_tr + uniq_size);
num_out_act = new_end - ptr_tr - 1;
}});
return num_out_act;
return num_out_act;
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
...
@@ -757,8 +781,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -757,8 +781,8 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2
(
self
):
def
generate_conv_inds_stage2
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq, out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, indice_pairs_uniq,
indice_pairs_uniq_before_sort,
out_inds"
,
"tv::Tensor"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"num_out_act"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"output_dims, input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
...
@@ -770,8 +794,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -770,8 +794,11 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
int kv = tv::arrayops::prod(
ksize
);
int kv =
ksize.op<
tv::arrayops::prod
>
();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// indice_pairs_uniq: [indice_pairs.size() / 2 + 1]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
...
@@ -787,22 +814,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -787,22 +814,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO handle invalid num_out_act
// TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
using V =
{
self
.
dtype_indices
}
;
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType =
{
self
.
dtype_indices
}
;
using V =
{
self
.
dtype_indices
}
;
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max();
using K = TV_DECLTYPE(I);
using table_t =
using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>;
tv::hash::default_empty_key_v<K>, false>;
using pair_t = typename table_t::value_type;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
TV_ASSERT_RT_ERR(hashdata.dim(0) >= num_out_act, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0));
tv::hash::clear_map_split(hash, custream);
hash.clear(custream);
// hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const
{
self
.
dtype_indices
}
>(),
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
loc_iter.layout_npq, num_out_act);
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2<table_t>, hash,
indice_pairs[1].data_ptr<int>(), indices.dim(0),
indice_pairs_uniq_before_sort.data_ptr<const K>(),
indice_pairs.dim(2));
indice_pairs[1].data_ptr<int>(),
indices.dim(0),
indice_pairs.dim(2));
}});
return num_out_act;
return num_out_act;
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
...
@@ -824,9 +854,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -824,9 +854,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
int kv = tv::arrayops::prod(ksize);
int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<
{
self
.
dtype_indices
}
>::max(),
"kernel volume must smaller than max value of
{
self
.
dtype_indices
}
");
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_uniq: [indice_pairs_bwd.size() + 1]
// indice_pairs_uniq: [indice_pairs_bwd.size() + 1]
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
...
@@ -842,20 +870,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -842,20 +870,25 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
tv::cuda::Launch launcher_clean_uniq(uniq_size, reinterpret_cast<cudaStream_t>(stream_int));
launcher_clean_uniq(clean_indices_uniq, indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>(), uniq_size);
tv::dispatch<int32_t, int64_t>(indice_pairs_uniq.dtype(), [&](auto I){{
launcher_num_act_in(calc_conv_indices_stage1_mask, loc_iter, indices.data_ptr<const int>(),
using T = TV_DECLTYPE(I);
indice_pairs_bwd.data_ptr<
{
self
.
dtype_indices
}
>(),
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<T>::max(),
indice_pairs_uniq.data_ptr<
{
self
.
dtype_indices
}
>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
"kernel volume must smaller than max value of T");
kv, transposed);
launcher_clean_uniq(clean_indices_uniq<T>, indice_pairs_uniq.data_ptr<T>(), uniq_size);
launcher_num_act_in(calc_conv_indices_stage1_mask<T>, loc_iter, indices.data_ptr<const int>(),
indice_pairs_bwd.data_ptr<
{
self
.
dtype_indices
}
>(),
indice_pairs_uniq.data_ptr<T>(), indice_num_per_loc.data_ptr<int>(), indices.dim(0),
kv, transposed);
}});
"""
)
"""
)
return
code
# .ret("int")
return
code
# .ret("int")
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_conv_inds_stage2_mask
(
self
):
def
generate_conv_inds_stage2_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
code
.
arg
(
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds"
,
"indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq,
indice_pairs_uniq_before_sort,
out_inds"
,
"tv::Tensor"
)
"tv::Tensor"
)
code
.
arg
(
"mask_fwd, mask_bwd"
,
"tv::Tensor"
)
code
.
arg
(
"mask_fwd, mask_bwd"
,
"tv::Tensor"
)
...
@@ -870,12 +903,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -870,12 +903,13 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
auto custream = reinterpret_cast<cudaStream_t>(stream_int);
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
int kv = tv::arrayops::prod(
ksize
);
int kv =
ksize.op<
tv::arrayops::prod
>
();
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_bwd: [kv, indices.dim(0)]
// indice_pairs_fwd: [kv, out_inds.dim(0)]
// indice_pairs_fwd: [kv, out_inds.dim(0)]
auto ctx = tv::Context();
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
ctx.set_cuda_stream(custream);
TV_ASSERT_RT_ERR(hashdata_k.dtype() == indice_pairs_uniq.dtype(), "error");
TV_ASSERT_RT_ERR(hashdata_v.dtype() == tv::int32, "error");
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
// auto timer = tv::CudaContextTimer<>();
// auto timer = tv::CudaContextTimer<>();
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
tv::check_shape(indice_pairs_bwd, {{kv, indices.dim(0)}});
...
@@ -892,45 +926,48 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -892,45 +926,48 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
// TODO handle invalid num_out_act
// TODO handle invalid num_out_act
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
indice_pairs_uniq = indice_pairs_uniq.slice_first_axis(0, num_out_act);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
tv::cuda::Launch lanucher_build_hash(num_out_act, custream);
using V =
{
self
.
dtype_indices
}
;
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType =
{
self
.
dtype_indices
}
;
using V =
{
self
.
dtype_indices
}
;
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max();
using K = TV_DECLTYPE(I);
using table_t =
using table_t =
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey, false>;
tv::hash::default_empty_key_v<K>, false>;
using pair_t = typename table_t::value_type;
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= num_out_act, "hash size not enough");
TV_ASSERT_RT_ERR(hashdata.dim(0) >= num_out_act, "hash size not enough");
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0));
tv::hash::clear_map_split(hash, custream);
hash.clear(custream);
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
lanucher_build_hash(build_conv_hash_table<table_t>, hash,
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const
{
self
.
dtype_indices
}
>(),
out_inds.data_ptr<int>(), indice_pairs_uniq.data_ptr<const K>(),
loc_iter.layout_npq, num_out_act);
loc_iter.layout_npq, num_out_act);
if (!mask_bwd.empty()){{
if (!mask_bwd.empty()){{
// auto timer = tv::CudaContextTimer<>();
// auto timer = tv::CudaContextTimer<>();
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
launcher_num_act_in(calc_conv_indices_stage2_mask<table_t>, hash,
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
indice_pairs_uniq_before_sort.data_ptr<K>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
mask_fwd.data_ptr<uint32_t>(), mask_bwd.data_ptr<uint32_t>(),
// tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0);
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(),
// tv::ssprint("calc_conv_indices_stage2_mask", timer.report() / 1000.0);
mask_bwd.data_ptr<uint32_t>(),
launcher_num_act_in_no_y(calc_conv_indices_stage2_mask_output, indice_pairs_bwd.data_ptr<int>(),
indice_pairs_bwd.dim(1), kv);
mask_bwd.data_ptr<uint32_t>(),
// tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0);
indice_pairs_bwd.dim(1), kv);
if (mask_fwd.dim(0) == 2){{
// tv::ssprint("calc_conv_indices_stage2_mask_output", timer.report() / 1000.0);
mask_fwd[1].copy_(mask_fwd[0], ctx);
if (mask_fwd.dim(0) == 2){{
}}
mask_fwd[1].copy_(mask_fwd[0], ctx);
if (mask_bwd.dim(0) == 2){{
}}
mask_bwd[1].copy_(mask_bwd[0], ctx);
if (mask_bwd.dim(0) == 2){{
}}
mask_bwd[1].copy_(mask_bwd[0], ctx);
}}else{{
}}
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
}}else{{
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
launcher_num_act_in(calc_conv_indices_stage2_inference_mask<table_t>, hash,
mask_fwd.data_ptr<uint32_t>(),
indice_pairs_fwd.data_ptr<int>(), indice_pairs_bwd.data_ptr<int>(),
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
indice_pairs_uniq_before_sort.data_ptr<K>(),
if (mask_fwd.dim(0) == 2){{
mask_fwd.data_ptr<uint32_t>(),
mask_fwd[1].copy_(mask_fwd[0], ctx);
indice_pairs_bwd.dim(1), indice_pairs_fwd.dim(1));
if (mask_fwd.dim(0) == 2){{
mask_fwd[1].copy_(mask_fwd[0], ctx);
}}
}}
}}
}}
}}
);
return num_out_act;
return num_out_act;
"""
)
"""
)
return
code
.
ret
(
"int"
)
return
code
.
ret
(
"int"
)
...
@@ -938,7 +975,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -938,7 +975,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
static_function
@
pccm
.
cuda
.
static_function
def
generate_subm_conv_inds
(
self
):
def
generate_subm_conv_inds
(
self
):
code
=
pccm
.
FunctionCode
()
code
=
pccm
.
FunctionCode
()
code
.
arg
(
"indices, hashdata"
,
"tv::Tensor"
)
code
.
arg
(
"indices, hashdata
_k, hashdata_v
"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, out_inds, indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"indice_pairs, out_inds, indice_num_per_loc"
,
"tv::Tensor"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"batch_size"
,
"int"
)
code
.
arg
(
"input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"input_dims"
,
f
"tv::array<int,
{
self
.
ndim
}
>"
)
...
@@ -953,7 +990,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -953,7 +990,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
auto ctx = tv::Context();
auto ctx = tv::Context();
ctx.set_cuda_stream(custream);
ctx.set_cuda_stream(custream);
if (!indice_pair_mask.empty()){{
if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(tv::arrayops::prod(
ksize
) < 32, "for now only support 32bit mask");
TV_ASSERT_INVALID_ARG(
ksize.op<
tv::arrayops::prod
>
() <
=
32, "for now only support 32bit mask");
}}
}}
// TODO stream
// TODO stream
// TODO handle num input == 0
// TODO handle num input == 0
...
@@ -963,7 +1000,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -963,7 +1000,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
stride[i] = 1;
stride[i] = 1;
padding[i] = (ksize[i] / 2) * dilation[i];
padding[i] = (ksize[i] / 2) * dilation[i];
}}
}}
int kv = tv::arrayops::prod(
ksize
);
int kv =
ksize.op<
tv::arrayops::prod
>
();
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
TV_ASSERT_RT_ERR(kv == indice_pairs.dim(1), "error");
// indice_pairs: [2, kv, indices.dim(0)]
// indice_pairs: [2, kv, indices.dim(0)]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
// out_inds: [MaxSize,
{
self
.
ndim
+
1
}
]
...
@@ -972,53 +1009,55 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
...
@@ -972,53 +1009,55 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
tv::cuda::Launch launcher_num_act_in(indices.dim(0), custream);
launcher_num_act_in.blocks.y = (kv / 2) + 1;
launcher_num_act_in.blocks.y = (kv / 2) + 1;
// launcher_num_act_in.blocks.y = kv;
// launcher_num_act_in.blocks.y = kv;
TV_ASSERT_RT_ERR(tv::arrayops::prod(input_dims) <= std::numeric_limits<
{
self
.
dtype_indices
}
>::max(),
"kernel volume must smaller than max value of
{
self
.
dtype_indices
}
");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
tv::cuda::Launch lanucher_build_hash(indices.dim(0), custream);
tv::cuda::Launch lanucher_build_hash(indices.dim(0), custream);
using V =
{
self
.
dtype_indices
}
;
tv::dispatch<int32_t, int64_t>(hashdata_k.dtype(), [&](auto I){{
using KeyType =
{
self
.
dtype_indices
}
;
using V =
{
self
.
dtype_indices
}
;
constexpr KeyType kEmptyKey = std::numeric_limits<KeyType>::max();
using K = TV_DECLTYPE(I);
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<K>::max(),
using table_t =
"kernel volume must smaller than max value of K");
tv::hash::LinearHashTable<KeyType, V, tv::hash::Murmur3Hash<KeyType>,
kEmptyKey, false>;
using table_t =
using pair_t = typename table_t::value_type;
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
TV_ASSERT_RT_ERR(hashdata.dim(0) >= indices.dim(0), "hash size not enough");
tv::hash::default_empty_key_v<K>, false>;
table_t hash = table_t(hashdata.data_ptr<pair_t>(), hashdata.dim(0));
TV_ASSERT_RT_ERR(hashdata_k.dim(0) >= indices.dim(0), "hash size not enough");
hash.clear(custream);
table_t hash = table_t(hashdata_k.data_ptr<K>(), hashdata_v.data_ptr<V>(), hashdata_k.dim(0));
// tv::ssprint("clear hash time", hashdata.dim(0), timer.report() / 1000.0);
tv::hash::clear_map_split(hash, custream);
lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(),
lanucher_build_hash(build_subm_conv_hash_table<table_t>, hash, indices.data_ptr<const int>(),
loc_iter.layout_npq, indices.dim(0));
loc_iter.layout_npq, indices.dim(0));
// tv::ssprint("build_hash time", timer.report() / 1000.0);
// tv::ssprint("build_hash time", timer.report() / 1000.0);
if (!indice_pair_mask.empty()){{
if (!indice_pair_mask.empty()){{
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
TV_ASSERT_INVALID_ARG(indice_pair_mask.ndim() == 2, "error");
if (indice_pair_mask.dim(0) == 2){{
if (indice_pair_mask.dim(0) == 2){{
auto mask_0 = indice_pair_mask[0];
auto mask_0 = indice_pair_mask[0];
tv::cuda::Launch lanucher_fill(mask_0.size(), custream);
tv::cuda::Launch lanucher_fill(mask_0.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size());
lanucher_fill(cudakers::fill_kernel<uint32_t>, mask_0.data_ptr<uint32_t>(), (1 << (kv / 2)), mask_0.size());
indice_pair_mask[1].zero_(ctx);
indice_pair_mask[1].zero_(ctx);
auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
auto kernel = &calc_subm_conv_indices_split_mask<table_t>;
launcher_num_act_in(kernel, loc_iter, hash,
launcher_num_act_in(kernel, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask[0].data_ptr<uint32_t>(), indice_pair_mask[1].data_ptr<uint32_t>(),
indice_pair_mask[0].data_ptr<uint32_t>(), indice_pair_mask[1].data_ptr<uint32_t>(),
indices.dim(0), indice_pairs.dim(2), kv);
indices.dim(0), indice_pairs.dim(2), kv);
}}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream);
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size());
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
}}else{{
}}else{{
tv::cuda::Launch lanucher_fill(indice_pair_mask.size(), custream);
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<int>(),
lanucher_fill(cudakers::fill_kernel<uint32_t>, indice_pair_mask.data_ptr<uint32_t>(), (1 << (kv / 2)), indice_pair_mask.size());
indice_pairs.data_ptr<int>(),
TV_ASSERT_RT_ERR(indice_pair_mask.dim(0) == 1, "error");
indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
launcher_num_act_in(calc_subm_conv_indices_mask<table_t>, loc_iter, hash,
indices.data_ptr<int>(), indice_pairs.data_ptr<int>(),
indice_pair_mask.data_ptr<uint32_t>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
}}
}}else{{
launcher_num_act_in(calc_subm_conv_indices<table_t>, loc_iter, hash, indices.data_ptr<int>(),
}});
indice_pairs.data_ptr<int>(),
// tv::ssprint("clear hash time", hashdata.dim(0), timer.report() / 1000.0);
indice_num_per_loc.data_ptr<int>(), indices.dim(0), indice_pairs.dim(2), kv);
}}
// tv::ssprint("gem subm conv inds time", timer.report() / 1000.0);
// tv::ssprint("gem subm conv inds time", timer.report() / 1000.0);
return indices.dim(0);
return indices.dim(0);
"""
)
"""
)
...
@@ -1057,8 +1096,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
...
@@ -1057,8 +1096,9 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
stride[i] = 1;
stride[i] = 1;
padding[i] = (ksize[i] / 2) * dilation[i];
padding[i] = (ksize[i] / 2) * dilation[i];
}}
}}
int kv = tv::arrayops::prod(ksize);
int kv = ksize.op<tv::arrayops::prod>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<
{
self
.
dtype_indices
}
>::max(),
"kernel volume must smaller than max value of
{
self
.
dtype_indices
}
");
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, input_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2);
int indices_pair_size = indice_pairs.dim(2);
...
@@ -1116,7 +1156,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
...
@@ -1116,7 +1156,7 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
f
"tv::array<int,
{
self
.
ndim
}
>"
)
f
"tv::array<int,
{
self
.
ndim
}
>"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
arg
(
"transposed"
,
f
"bool"
,
"false"
)
code
.
raw
(
f
"""
code
.
raw
(
f
"""
int kv = tv::arrayops::prod(
ksize
);
int kv =
ksize.op<
tv::arrayops::prod
>
();
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvProblem problem(batch_size, 1, 1, input_dims, output_dims, ksize, padding, stride, dilation);
ConvLocIter loc_iter(problem);
ConvLocIter loc_iter(problem);
int indices_pair_size = indice_pairs.dim(2);
int indices_pair_size = indice_pairs.dim(2);
...
@@ -1125,6 +1165,8 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
...
@@ -1125,6 +1165,8 @@ class SparseConvIndicesCPU(pccm.ParameterizedClass):
std::unordered_map<
{
self
.
dtype_indices
}
,
{
self
.
dtype_indices
}
> hash;
std::unordered_map<
{
self
.
dtype_indices
}
,
{
self
.
dtype_indices
}
> hash;
auto indices_ptr = indices.data_ptr<
{
self
.
dtype_indices
}
>();
auto indices_ptr = indices.data_ptr<
{
self
.
dtype_indices
}
>();
auto out_inds_ptr = out_inds.data_ptr<
{
self
.
dtype_indices
}
>();
auto out_inds_ptr = out_inds.data_ptr<
{
self
.
dtype_indices
}
>();
TV_ASSERT_RT_ERR(input_dims.op<tv::arrayops::prod>() < std::numeric_limits<
{
self
.
dtype_indices
}
>::max(),
"kernel volume must smaller than max value of
{
self
.
dtype_indices
}
");
int indice_in_num = indices.dim(0);
int indice_in_num = indices.dim(0);
int num_act = 0;
int num_act = 0;
...
...
spconv/csrc/sparse/maxpool.py
View file @
7af751dc
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
contextlib
import
contextlib
from
cumm.conv.bases
import
ConvEnum
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm
import
dtypes
from
cumm
import
dtypes
import
pccm
import
pccm
...
@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
...
spconv/csrc/sparse/pointops.py
View file @
7af751dc
...
@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
...
@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
TensorViewHashKernel
)
self
.
add_dependency
(
TensorView
,
TensorViewHashKernel
)
self
.
add_param_class
(
"layout_ns"
,
layout
,
"Layout"
)
self
.
add_param_class
(
"layout_ns"
,
layout
,
"Layout"
)
self
.
add_include
(
"tensorview/hash/ops.h"
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
ndim
=
ndim
self
.
ndim
=
ndim
self
.
zyx
=
zyx
self
.
zyx
=
zyx
...
@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
...
@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
num_per_voxel.zero_(ctx);
num_per_voxel.zero_(ctx);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
hash
.
clear
(
custream);
tv::
hash
::
clear
_map(hash,
custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream);
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
point_indice_data.data_ptr<int64_t>(),
point_indice_data.data_ptr<int64_t>(),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment