Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
7af751dc
Commit
7af751dc
authored
Jul 12, 2022
by
yan.yan
Browse files
sync
parent
647927ce
Changes
28
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1739 additions
and
608 deletions
+1739
-608
CHANGELOG.md
CHANGELOG.md
+5
-0
example/simple_hash.py
example/simple_hash.py
+1
-1
spconv/algo.py
spconv/algo.py
+197
-38
spconv/algocore.py
spconv/algocore.py
+133
-0
spconv/build.py
spconv/build.py
+9
-10
spconv/constants.py
spconv/constants.py
+20
-0
spconv/core.py
spconv/core.py
+86
-54
spconv/core_cc/csrc/sparse/all/__init__.pyi
spconv/core_cc/csrc/sparse/all/__init__.pyi
+71
-6
spconv/core_cc/csrc/sparse/alloc.pyi
spconv/core_cc/csrc/sparse/alloc.pyi
+54
-0
spconv/core_cc/cumm/__init__.pyi
spconv/core_cc/cumm/__init__.pyi
+14
-0
spconv/core_cc/cumm/conv/main.pyi
spconv/core_cc/cumm/conv/main.pyi
+1
-102
spconv/core_cc/cumm/gemm/main.pyi
spconv/core_cc/cumm/gemm/main.pyi
+3
-140
spconv/csrc/hash/core.py
spconv/csrc/hash/core.py
+35
-32
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+474
-28
spconv/csrc/sparse/alloc.py
spconv/csrc/sparse/alloc.py
+195
-0
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+207
-0
spconv/csrc/sparse/cpu_core.py
spconv/csrc/sparse/cpu_core.py
+4
-9
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+220
-178
spconv/csrc/sparse/maxpool.py
spconv/csrc/sparse/maxpool.py
+8
-9
spconv/csrc/sparse/pointops.py
spconv/csrc/sparse/pointops.py
+2
-1
No files found.
CHANGELOG.md
View file @
7af751dc
# Changelog
# Changelog
## [2.1.22] - 2022-4-14
### Added
-
add full nvrtc support
-
add support for large spatial shapes and batch sizes. If a large shape is detected, int64 is used instead of int32 when hashing.
## [2.1.21] - 2021-12-9
## [2.1.21] - 2021-12-9
### Added
### Added
-
add sm_37
-
add sm_37
...
...
example/simple_hash.py
View file @
7af751dc
...
@@ -56,7 +56,7 @@ def main():
...
@@ -56,7 +56,7 @@ def main():
is_empty
=
table
.
insert_exist_keys
(
keys
,
values
)
is_empty
=
table
.
insert_exist_keys
(
keys
,
values
)
ks
,
vs
,
cnt
=
table
.
items
()
ks
,
vs
,
cnt
=
table
.
items
()
cnt_item
=
cnt
.
item
()
cnt_item
=
cnt
.
item
()
print
(
cnt
,
ks
[:
cnt_item
],
vs
[:
cnt_item
])
print
(
cnt
,
ks
[:
cnt_item
],
vs
[:
cnt_item
]
,
is_empty
.
dtype
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
spconv/algo.py
View file @
7af751dc
...
@@ -12,28 +12,50 @@
...
@@ -12,28 +12,50 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
enum
import
Enum
import
contextlib
from
cumm
import
tensorview
as
tv
from
typing
import
Dict
,
List
,
Set
,
Tuple
,
Union
from
spconv.core_cc.cumm.gemm.main
import
GemmAlgoDesp
,
GemmMainUnitTest
,
GemmParams
from
spconv.core_cc.cumm.conv.main
import
ConvAlgoDesp
,
ConvMainUnitTest
,
ConvParams
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.gemm.algospec.core
import
GemmAlgo
,
ShuffleStrideType
,
get_min_arch_of_algo_str
,
get_available_algo_str_from_arch
from
cumm.gemm.codeops
import
group_by
,
div_up
from
spconv.constants
import
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
from
typing
import
Optional
import
time
import
time
from
enum
import
Enum
from
threading
import
Lock
from
threading
import
Lock
import
contextlib
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
spconv.core
import
ConvAlgo
,
AlgoHint
from
cumm
import
tensorview
as
tv
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.conv.kernel
import
ConvKernel
from
cumm.gemm.kernel
import
GemmKernel
from
cumm.gemm.algospec.core
import
(
GemmAlgo
,
ShuffleStrideType
,
get_available_algo_str_from_arch
,
get_min_arch_of_algo_str
)
from
cumm.gemm.codeops
import
div_up
,
group_by
from
cumm.nvrtc
import
CummNVRTCModule
,
get_cudadevrt_path
from
cumm.tensorview.gemm
import
ConvAlgoDesp
from
cumm.tensorview.gemm
import
ConvOpType
as
ConvOpTypeCpp
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm
import
dtypes
from
spconv.constants
import
(
NDIM_DONT_CARE
,
SPCONV_BWD_SPLITK
,
SPCONV_NVRTC_MODE
,
SPCONV_DEBUG_NVRTC_KERNELS
)
from
spconv.core
import
ALL_IMPGEMM_PARAMS
,
AlgoHint
,
ConvAlgo
from
spconv.core_cc.cumm.conv.main
import
ConvMainUnitTest
from
spconv.core_cc.cumm.gemm.main
import
GemmMainUnitTest
from
spconv.cppconstants
import
COMPILED_CUDA_ARCHS
from
cumm.tensorview.gemm
import
NVRTCParams
from
spconv.tools
import
CUDAKernelTimer
from
spconv.tools
import
CUDAKernelTimer
from
cumm.gemm.constants
import
NVRTCConstants
,
NVRTCMode
from
spconv
import
algocore
from
cumm.conv.main
import
gen_gemm_kernels
as
gen_conv_kernels
from
cumm.gemm.main
import
gen_gemm_kernels
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
ALL_CONV_ALGO_DESPS
=
ConvMainUnitTest
.
get_all_conv_algo_desp
()
ALL_CONV_ALGO_DESPS
=
ConvMainUnitTest
.
get_all_conv_algo_desp
()
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
class
SimpleGemmAlgoMeta
:
class
SimpleGemmAlgoMeta
:
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
tile_ks
:
List
[
int
],
tile_ks
:
List
[
int
],
...
@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
...
@@ -45,22 +67,68 @@ class SimpleGemmAlgoMeta:
class
BestAlgoByProfile
:
class
BestAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
splitk
=
splitk
self
.
arch
=
arch
class
BestConvAlgoByProfile
:
class
BestConvAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
splitk
:
int
=
1
)
->
None
:
def
__init__
(
self
,
algo_desp
:
ConvAlgoDesp
,
arch
:
Tuple
[
int
,
int
],
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
algo_desp
=
algo_desp
self
.
splitk
=
splitk
self
.
splitk
=
splitk
self
.
arch
=
arch
def _get_nvrtc_params(mod: CummNVRTCModule, ker: Union[GemmKernel, ConvKernel],
                      kernel_name: str):
    """Describe how to launch an NVRTC-compiled kernel taken from ``mod``.

    Which fields of the returned ``NVRTCParams`` are filled depends on the
    module-level ``SPCONV_NVRTC_MODE``:

    * ``DynamicParallism``: only the lowered ``<ns>::nvrtc_kernel`` name.
    * ``Direct``: only the lowered ``<ns>::<kernel_name>`` name.
    * ``KernelAndCPU``: kernel name, the ``nvrtc_kernel_cpu_out`` init kernel,
      ``param_size`` and two ``uint8`` buffers of ``param_size`` elements.
    * ``ConstantMemory``: kernel name, init kernel, ``param_size``, the
      lowered name of the constant-parameter symbol and one ``uint8`` buffer.

    Args:
        mod: NVRTC module used to lower names and to look up constants.
        ker: kernel object supplying ``num_threads``, ``smem_size`` and
            ``namespace``.
        kernel_name: unqualified kernel entry-point name inside the
            kernel's namespace.

    Returns:
        The populated ``tv.gemm.NVRTCParams`` instance.

    Raises:
        NotImplementedError: if ``SPCONV_NVRTC_MODE`` is none of the four
            modes listed above.
    """
    mode = SPCONV_NVRTC_MODE

    # Fields shared by every mode.
    result = tv.gemm.NVRTCParams()
    result.cumodule = mod.get_cpp_object()
    result.mode = mode.value
    result.num_threads = ker.num_threads
    result.smem_size = ker.smem_size
    scope = ker.namespace

    # Modes that only need the entry-point name.
    if mode == NVRTCMode.DynamicParallism:
        result.kernel_name = mod.get_lowered_name(f"{scope}::nvrtc_kernel")
        return result
    if mode == NVRTCMode.Direct:
        result.kernel_name = mod.get_lowered_name(f"{scope}::{kernel_name}")
        return result

    uses_cpu_buffer = mode == NVRTCMode.KernelAndCPU
    uses_constant = mode == NVRTCMode.ConstantMemory
    if not (uses_cpu_buffer or uses_constant):
        raise NotImplementedError

    # Both remaining modes need an init kernel and a parameter buffer.
    result.kernel_name = mod.get_lowered_name(f"{scope}::{kernel_name}")
    result.init_kernel_name = mod.get_lowered_name(
        f"{scope}::nvrtc_kernel_cpu_out")
    result.param_size = mod.const_values[
        f"{scope}::{NVRTCConstants.SIZEOF_KEY}"]
    if uses_constant:
        result.constant_name = mod.get_lowered_name(
            f"&{scope}::{NVRTCConstants.CONSTANT_PARAM_KEY}")
    # NOTE(review): the third argument of tv.empty is presumably a device
    # index (0 here, -1 for the pinned buffer below) -- confirm against cumm.
    result.param_storage = tv.empty([result.param_size], tv.uint8, 0)
    if uses_cpu_buffer:
        result.param_storage_cpu = tv.empty([result.param_size],
                                            tv.uint8,
                                            -1,
                                            pinned=True)
    return result
class
SimpleGemm
:
class
SimpleGemm
:
def
__init__
(
self
,
desps
:
List
[
GemmAlgoDesp
])
->
None
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
GemmAlgoDesp
])
->
None
:
self
.
desps
=
desps
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
if
SPCONV_DEBUG_NVRTC_KERNELS
:
self
.
prebuilt_desp_names
.
clear
()
self
.
lock
=
Lock
()
self
.
lock
=
Lock
()
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
desps
)
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_
desps
)
self
.
static_key_to_meta
:
Dict
[
_GEMM_STATIC_KEY
,
self
.
static_key_to_meta
:
Dict
[
_GEMM_STATIC_KEY
,
SimpleGemmAlgoMeta
]
=
{}
SimpleGemmAlgoMeta
]
=
{}
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
...
@@ -94,15 +162,44 @@ class SimpleGemm:
...
@@ -94,15 +162,44 @@ class SimpleGemm:
self
.
mn_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
],
self
.
mn_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
],
BestAlgoByProfile
]
=
{}
# for backward weight
BestAlgoByProfile
]
=
{}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
@
staticmethod
@
staticmethod
def
get_static_key
(
d
:
GemmAlgoDesp
)
->
_GEMM_STATIC_KEY
:
def
get_static_key
(
d
:
GemmAlgoDesp
)
->
_GEMM_STATIC_KEY
:
return
(
d
.
trans_a
,
d
.
trans_b
,
d
.
trans_c
,
d
.
dtype_a
,
d
.
dtype_b
,
return
(
d
.
trans_a
,
d
.
trans_b
,
d
.
trans_c
,
d
.
dtype_a
,
d
.
dtype_b
,
d
.
dtype_c
,
d
.
shuffle_type
,
d
.
algo
)
d
.
dtype_c
,
d
.
shuffle_type
.
value
,
d
.
algo
)
def
device_synchronize
(
self
):
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
return
GemmMainUnitTest
.
device_synchronize
()
def _compile_nvrtc_module(self, desp: GemmAlgoDesp):
    """Generate the GEMM kernel for ``desp`` and load it as an NVRTC module.

    Args:
        desp: GEMM algorithm descriptor to turn into a kernel.

    Returns:
        ``(module, kernel)``: the ``CummNVRTCModule`` after ``load()`` and
        the generated kernel object it was built from.

    Raises:
        AssertionError: in ``DynamicParallism`` mode when
            ``get_cudadevrt_path()`` returns ``None``.
    """
    gemm_kernel = gen_gemm_kernels(algocore.get_gemm_param_from_desp(desp),
                                   SPCONV_NVRTC_MODE)
    gemm_kernel.namespace = "spconv"

    # ConstantMemory mode additionally registers the constant-parameter
    # symbol as a custom name on the module.
    extra_names = []
    if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
        extra_names = [
            f"&{gemm_kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
        ]

    # Only DynamicParallism mode passes a cudadevrt path; other modes pass "".
    devrt_path = ""
    if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
        located = get_cudadevrt_path()
        assert located is not None, "DynamicParallism must have cudadevrt"
        devrt_path = str(located)

    module = CummNVRTCModule([gemm_kernel],
                             cudadevrt_path=devrt_path,
                             verbose=False,
                             custom_names=extra_names)
    module.load()
    return module, gemm_kernel
def _cached_get_nvrtc_params(self, desp: GemmAlgoDesp, arch: Tuple[int, int]):
    """Return NVRTC launch parameters for ``desp``, compiling on first use.

    Results are memoised in ``self._nvrtc_caches`` under the key
    ``(str(desp), arch)``.

    Args:
        desp: GEMM algorithm descriptor.
        arch: architecture tuple; it is only used as part of the cache key.

    Returns:
        The cached or freshly built NVRTC parameter object.
    """
    cache_key = (str(desp), arch)
    try:
        return self._nvrtc_caches[cache_key]
    except KeyError:
        pass
    # Cache miss: compile the module, then derive and remember its parameters.
    module, kernel = self._compile_nvrtc_module(desp)
    built = _get_nvrtc_params(module, kernel, "gemm_kernel")
    self._nvrtc_caches[cache_key] = built
    return built
def
get_all_available
(
def
get_all_available
(
self
,
self
,
a
:
tv
.
Tensor
,
a
:
tv
.
Tensor
,
...
@@ -135,6 +232,11 @@ class SimpleGemm:
...
@@ -135,6 +232,11 @@ class SimpleGemm:
ldb
=
b
.
dim
(
1
)
ldb
=
b
.
dim
(
1
)
ldc
=
c
.
dim
(
1
)
ldc
=
c
.
dim
(
1
)
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
arch
not
in
COMPILED_CUDA_ARCHS
:
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
return
finally_algos
return
finally_algos
...
@@ -334,6 +436,8 @@ class SimpleGemm:
...
@@ -334,6 +436,8 @@ class SimpleGemm:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
params
=
GemmParams
()
params
=
GemmParams
()
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
a
=
a
params
.
a
=
a
params
.
b
=
b
params
.
b
=
b
params
.
c
=
c_
params
.
c
=
c_
...
@@ -361,7 +465,7 @@ class SimpleGemm:
...
@@ -361,7 +465,7 @@ class SimpleGemm:
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
splitk
=
spk
))
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
min_time
=
1000
min_time
=
1000
min_idx
=
-
1
min_idx
=
-
1
...
@@ -421,6 +525,9 @@ class SimpleGemm:
...
@@ -421,6 +525,9 @@ class SimpleGemm:
if
profile_res
.
splitk
>
1
:
if
profile_res
.
splitk
>
1
:
split_k_slices
=
profile_res
.
splitk
split_k_slices
=
profile_res
.
splitk
params
=
GemmParams
()
params
=
GemmParams
()
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
a
=
a
params
.
a
=
a
params
.
b
=
b
params
.
b
=
b
params
.
c
=
c
params
.
c
=
c
...
@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
...
@@ -461,11 +568,14 @@ _CONV_STATIC_KEY = Tuple[int, int, int, int, int, int, int, int, int, str, int]
class
SimpleConv
:
class
SimpleConv
:
def
__init__
(
self
,
desps
:
List
[
ConvAlgoDesp
])
->
None
:
def
__init__
(
self
,
prebuilt_desps
:
List
[
ConvAlgoDesp
])
->
None
:
self
.
desps
=
desps
all_desps
=
[
algocore
.
get_conv_algo_desp_from_param
(
p
)
for
p
in
ALL_IMPGEMM_PARAMS
]
self
.
prebuilt_desps
=
prebuilt_desps
self
.
prebuilt_desp_names
=
{
str
(
d
)
for
d
in
prebuilt_desps
}
self
.
prebuilt_desp_names
.
clear
()
self
.
lock
=
Lock
()
self
.
lock
=
Lock
()
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
desps
)
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
all_
desps
)
self
.
static_key_to_meta
:
Dict
[
_CONV_STATIC_KEY
,
self
.
static_key_to_meta
:
Dict
[
_CONV_STATIC_KEY
,
SimpleGemmAlgoMeta
]
=
{}
SimpleGemmAlgoMeta
]
=
{}
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
...
@@ -500,28 +610,36 @@ class SimpleConv:
...
@@ -500,28 +610,36 @@ class SimpleConv:
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
self
.
kc_wgrad_cache
:
Dict
[
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
,
int
,
int
],
BestConvAlgoByProfile
]
=
{
int
],
BestConvAlgoByProfile
]
=
{
}
# for backward weight
}
# for backward weight
self
.
_nvrtc_caches
:
Dict
[
Tuple
[
str
,
Tuple
[
int
,
int
]],
NVRTCParams
]
=
{}
@
staticmethod
@
staticmethod
def
get_static_key
(
d
:
ConvAlgoDesp
)
->
_CONV_STATIC_KEY
:
def
get_static_key
(
d
:
ConvAlgoDesp
)
->
_CONV_STATIC_KEY
:
return
(
d
.
layout_i
,
d
.
layout_w
,
d
.
layout_o
,
d
.
interleave_i
,
return
(
d
.
layout_i
.
value
,
d
.
layout_w
.
value
,
d
.
layout_o
.
value
,
d
.
interleave_w
,
d
.
interleave_o
,
d
.
dtype_input
,
d
.
dtype_weight
,
d
.
interleave_i
,
d
.
interleave_w
,
d
.
interleave_o
,
d
.
dtype_input
,
d
.
dtype_output
,
d
.
algo
,
d
.
op_type
)
d
.
dtype_weight
,
d
.
dtype_output
,
d
.
algo
,
d
.
op_type
.
value
)
def
device_synchronize
(
self
):
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
return
GemmMainUnitTest
.
device_synchronize
()
def
get_all_available
(
self
,
inp
:
tv
.
Tensor
,
weight
:
tv
.
Tensor
,
def
get_all_available
(
self
,
out
:
tv
.
Tensor
,
layout_i
:
ConvLayout
,
inp
:
tv
.
Tensor
,
layout_w
:
ConvLayout
,
layout_o
:
ConvLayout
,
weight
:
tv
.
Tensor
,
arch
:
Tuple
[
int
,
int
],
op_type
:
ConvOpType
,
out
:
tv
.
Tensor
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
):
layout_i
:
ConvLayout
,
layout_w
:
ConvLayout
,
layout_o
:
ConvLayout
,
arch
:
Tuple
[
int
,
int
],
op_type
:
ConvOpType
,
mask_width
:
int
,
fp32_accum
:
Optional
[
bool
]
=
None
):
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
finally_algos
:
List
[
ConvAlgoDesp
]
=
[]
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
is_fp16
=
inp
.
dtype
==
tv
.
float16
and
weight
.
dtype
==
tv
.
float16
and
out
.
dtype
==
tv
.
float16
use_f32_as_accum
=
False
use_f32_as_accum
=
False
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
kv
=
int
(
np
.
prod
(
weight
.
shape
[
1
:
-
1
]))
# for 3d conv, if reduce axis is too large, may cause nan during
# for 3d conv, if reduce axis is too large, may cause nan during
# forward.
# forward.
if
is_fp16
:
if
is_fp16
:
if
fp32_accum
is
None
:
if
fp32_accum
is
None
:
...
@@ -551,7 +669,7 @@ class SimpleConv:
...
@@ -551,7 +669,7 @@ class SimpleConv:
if
use_f32_as_accum
:
if
use_f32_as_accum
:
if
desp
.
dacc
==
tv
.
float16
:
if
desp
.
dacc
==
tv
.
float16
:
continue
continue
ldi
=
inp
.
dim
(
-
1
)
ldi
=
inp
.
dim
(
-
1
)
ldw
=
weight
.
dim
(
-
1
)
ldw
=
weight
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
ldo
=
out
.
dim
(
-
1
)
...
@@ -560,6 +678,11 @@ class SimpleConv:
...
@@ -560,6 +678,11 @@ class SimpleConv:
assert
mask_width
>
0
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
arch
not
in
COMPILED_CUDA_ARCHS
:
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
return
finally_algos
return
finally_algos
...
@@ -592,6 +715,34 @@ class SimpleConv:
...
@@ -592,6 +715,34 @@ class SimpleConv:
return
desp
.
query_conv_workspace_size
(
mnk
[
0
],
mnk
[
1
],
mnk
[
2
],
splitk
,
return
desp
.
query_conv_workspace_size
(
mnk
[
0
],
mnk
[
1
],
mnk
[
2
],
splitk
,
kv
)
kv
)
def _compile_nvrtc_module(self, desp: ConvAlgoDesp):
    """Generate the conv kernel for ``desp`` and load it as an NVRTC module.

    Args:
        desp: convolution algorithm descriptor to turn into a kernel.

    Returns:
        ``(module, kernel)``: the ``CummNVRTCModule`` after ``load()`` and
        the generated kernel object it was built from.

    Raises:
        AssertionError: in ``DynamicParallism`` mode when
            ``get_cudadevrt_path()`` returns ``None``.
    """
    conv_kernel = gen_conv_kernels(algocore.get_conv_param_from_desp(desp),
                                   SPCONV_NVRTC_MODE)
    conv_kernel.namespace = "spconv"

    # ConstantMemory mode additionally registers the constant-parameter
    # symbol as a custom name on the module.
    extra_names = []
    if SPCONV_NVRTC_MODE == NVRTCMode.ConstantMemory:
        extra_names = [
            f"&{conv_kernel.namespace}::{NVRTCConstants.CONSTANT_PARAM_KEY}"
        ]

    # Only DynamicParallism mode passes a cudadevrt path; other modes pass "".
    devrt_path = ""
    if SPCONV_NVRTC_MODE == NVRTCMode.DynamicParallism:
        located = get_cudadevrt_path()
        assert located is not None, "DynamicParallism must have cudadevrt"
        devrt_path = str(located)

    module = CummNVRTCModule([conv_kernel],
                             cudadevrt_path=devrt_path,
                             verbose=False,
                             custom_names=extra_names)
    module.load()
    return module, conv_kernel
def _cached_get_nvrtc_params(self, desp: ConvAlgoDesp, arch: Tuple[int, int]):
    """Return NVRTC launch parameters for ``desp``, compiling on first use.

    Results are memoised in ``self._nvrtc_caches`` under the key
    ``(str(desp), arch)``.

    Args:
        desp: convolution algorithm descriptor.
        arch: architecture tuple; it is only used as part of the cache key.

    Returns:
        The cached or freshly built NVRTC parameter object.
    """
    cache_key = (str(desp), arch)
    try:
        return self._nvrtc_caches[cache_key]
    except KeyError:
        pass
    # Cache miss: compile the module, then derive and remember its parameters.
    module, kernel = self._compile_nvrtc_module(desp)
    built = _get_nvrtc_params(module, kernel, "conv_kernel")
    self._nvrtc_caches[cache_key] = built
    return built
def
tune_and_cache
(
self
,
def
tune_and_cache
(
self
,
op_type
:
ConvOpType
,
op_type
:
ConvOpType
,
inp
:
tv
.
Tensor
,
inp
:
tv
.
Tensor
,
...
@@ -613,7 +764,7 @@ class SimpleConv:
...
@@ -613,7 +764,7 @@ class SimpleConv:
stream
:
int
=
0
,
stream
:
int
=
0
,
fp32_accum
:
Optional
[
bool
]
=
None
):
fp32_accum
:
Optional
[
bool
]
=
None
):
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
avail
=
self
.
get_all_available
(
inp
,
weight
,
output
,
layout_i
,
layout_w
,
layout_o
,
arch
,
op_type
,
mask_width
,
layout_o
,
arch
,
op_type
,
mask_width
,
fp32_accum
)
fp32_accum
)
inp
=
inp
.
clone
()
inp
=
inp
.
clone
()
weight
=
weight
.
clone
()
weight
=
weight
.
clone
()
...
@@ -626,7 +777,10 @@ class SimpleConv:
...
@@ -626,7 +777,10 @@ class SimpleConv:
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
all_profile_res
:
List
[
BestConvAlgoByProfile
]
=
[]
for
desp
in
avail
:
for
desp
in
avail
:
# for sparse conv, ndim isn't used, so we just provide a constant value.
# for sparse conv, ndim isn't used, so we just provide a constant value.
params
=
ConvParams
(
NDIM_DONT_CARE
,
op_type
.
value
)
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type
.
value
))
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
conv_algo_desp
=
desp
params
.
conv_algo_desp
=
desp
params
.
input
=
inp
params
.
input
=
inp
params
.
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
params
.
weight
=
weight
.
view
([
channel_k
,
-
1
,
channel_c
])
...
@@ -657,13 +811,16 @@ class SimpleConv:
...
@@ -657,13 +811,16 @@ class SimpleConv:
GemmMainUnitTest
.
stream_synchronize
(
stream
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
t
=
time
.
time
()
t
=
time
.
time
()
params
.
split_k_slices
=
spk
params
.
split_k_slices
=
spk
ConvMainUnitTest
.
implicit_gemm2
(
params
)
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
tv
.
gemm
.
run_nvrtc_conv_kernel
(
params
)
else
:
ConvMainUnitTest
.
implicit_gemm2
(
params
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
this_times
.
append
(
time
.
time
()
-
t
)
this_times
.
append
(
time
.
time
()
-
t
)
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
splitk
=
spk
))
all_profile_res
.
append
(
BestConvAlgoByProfile
(
desp
,
arch
,
splitk
=
spk
))
if
not
all_profile_res
:
if
not
all_profile_res
:
raise
ValueError
(
"can't find suitable algorithm for"
,
op_type
)
raise
ValueError
(
"can't find suitable algorithm for"
,
op_type
)
min_time
=
1000
min_time
=
1000
...
@@ -720,7 +877,9 @@ class SimpleConv:
...
@@ -720,7 +877,9 @@ class SimpleConv:
op_type_value
=
op_type
op_type_value
=
op_type
else
:
else
:
op_type_value
=
op_type
.
value
op_type_value
=
op_type
.
value
params
=
ConvParams
(
NDIM_DONT_CARE
,
op_type_value
)
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type_value
))
if
algo_desp
.
is_nvrtc
and
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
input
=
inp
params
.
input
=
inp
params
.
verbose
=
verbose
params
.
verbose
=
verbose
...
...
spconv/algocore.py
0 → 100644
View file @
7af751dc
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.gemm.algospec.core
import
(
GemmAlgo
,
ShuffleStrideType
)
from
cumm.tensorview.gemm
import
ConvAlgoDesp
from
cumm.tensorview.gemm
import
ConvIterAlgo
as
ConvIterAlgoCpp
from
cumm.tensorview.gemm
import
ConvOpType
as
ConvOpTypeCpp
from
cumm.tensorview.gemm
import
ConvLayoutType
as
ConvLayoutTypeCpp
from
cumm.tensorview.gemm
import
ShuffleStrideType
as
ShuffleStrideTypeCpp
from
cumm.tensorview.gemm
import
ConvParams
,
GemmAlgoDesp
,
GemmParams
from
cumm.gemm.main
import
GemmAlgoParams
from
cumm.conv.main
import
ConvAlgoParams
,
ConvIterAlgo
from
cumm
import
dtypes
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
cumm.gemm.core
import
MetaArray
from
cumm.gemm.algospec
import
TensorOp
def _assign_gemm_desp_props(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
                            p: Union[GemmAlgoParams, ConvAlgoParams]):
    """Copy the GEMM-level properties of ``p`` onto the descriptor ``desp``.

    This is the inverse of ``_assign_gemm_params``: every field written
    here is read back there from the same descriptor attribute.

    Args:
        desp: descriptor object that is mutated in place.
        p: Python-side algorithm parameters to read from.

    Returns:
        None. ``desp`` is updated in place.
    """
    # Each operand keeps its own dtype. The previous code copied dtype_a into
    # all three slots, which is wrong for mixed-precision parameter sets such
    # as "s8,s8,s32,s32,s32".
    desp.dtype_a = p.dtype_a.tv_dtype
    desp.dtype_b = p.dtype_b.tv_dtype
    desp.dtype_c = p.dtype_c.tv_dtype
    desp.dacc = p.dtype_acc.tv_dtype
    desp.dcomp = p.dtype_comp.tv_dtype

    desp.trans_a = p.trans_a
    desp.trans_b = p.trans_b
    desp.trans_c = p.trans_c

    # Tile shapes are stored on the descriptor as plain 3-tuples.
    desp.tile_shape = (p.ts[0], p.ts[1], p.ts[2])
    desp.warp_tile_shape = (p.wts[0], p.wts[1], p.wts[2])
    if p.tensorop is not None:
        # When there is no tensor op, the descriptor's tensorop is left
        # untouched.
        desp.tensorop = (p.tensorop[0], p.tensorop[1], p.tensorop[2])

    desp.num_stage = p.num_stage
    desp.algo = p.algo.value
    desp.split_k_serial = p.splitk_serial
    desp.split_k_parallel = p.splitk_parallel
    # Convert between the two enum types through the shared underlying value.
    desp.shuffle_type = ShuffleStrideTypeCpp(p.shuffle_stride.value)
    desp.access_per_vector = p.access_per_vector
    desp.is_nvrtc = p.is_nvrtc
def get_gemm_algo_desp_from_param(p: GemmAlgoParams):
    """Build a ``GemmAlgoDesp`` from the GEMM parameters ``p``.

    Args:
        p: GEMM algorithm parameters.

    Returns:
        A new ``GemmAlgoDesp`` filled by ``_assign_gemm_desp_props``.
    """
    gemm_desp = GemmAlgoDesp()
    _assign_gemm_desp_props(gemm_desp, p)
    return gemm_desp
def get_conv_algo_desp_from_param(p: ConvAlgoParams):
    """Build a ``ConvAlgoDesp`` from the convolution parameters ``p``.

    The GEMM-level fields are filled by ``_assign_gemm_desp_props``; the
    convolution-only fields (op type, iteration algorithm, layouts,
    interleaves, mask/increment flags) are copied here.

    Args:
        p: convolution algorithm parameters.

    Returns:
        The populated ``ConvAlgoDesp``.
    """
    op_type_cpp = ConvOpTypeCpp(p.op_type.value)
    conv_desp = ConvAlgoDesp(p.ndim, op_type_cpp)
    _assign_gemm_desp_props(conv_desp, p)

    # Convolution-specific attributes.
    conv_desp.ndim = p.ndim
    conv_desp.op_type = op_type_cpp
    conv_desp.iter_algo = ConvIterAlgoCpp(p.iter_algo.value)

    # Layout type and interleave for input, weight and output.
    in_layout = p.layout_desp_input
    w_layout = p.layout_desp_weight
    out_layout = p.layout_desp_output
    conv_desp.layout_i = ConvLayoutTypeCpp(in_layout.layout_type.value)
    conv_desp.layout_w = ConvLayoutTypeCpp(w_layout.layout_type.value)
    conv_desp.layout_o = ConvLayoutTypeCpp(out_layout.layout_type.value)
    conv_desp.interleave_i = in_layout.interleave
    conv_desp.interleave_w = w_layout.interleave
    conv_desp.interleave_o = out_layout.interleave

    conv_desp.mask_sparse = p.mask_sparse
    conv_desp.increment_k_first = p.increment_k_first
    return conv_desp
def _assign_gemm_params(desp: Union[ConvAlgoDesp, GemmAlgoDesp],
                        p: Union[GemmAlgoParams, ConvAlgoParams]):
    """Copy the GEMM-level properties of the descriptor ``desp`` onto ``p``.

    Args:
        desp: descriptor to read from.
        p: Python-side algorithm parameters, mutated in place.

    Returns:
        None. ``p`` is updated in place.
    """
    # (parameter attribute, descriptor attribute) pairs for the five dtypes.
    dtype_fields = (
        ("dtype_a", "dtype_a"),
        ("dtype_b", "dtype_b"),
        ("dtype_c", "dtype_c"),
        ("dtype_acc", "dacc"),
        ("dtype_comp", "dcomp"),
    )
    for param_attr, desp_attr in dtype_fields:
        setattr(p, param_attr,
                dtypes.get_dtype_from_tvdtype(getattr(desp, desp_attr)))

    p.trans_a = desp.trans_a
    p.trans_b = desp.trans_b
    p.trans_c = desp.trans_c

    p.ts = MetaArray(*desp.tile_shape)
    p.wts = MetaArray(*desp.warp_tile_shape)

    # A tensor op is only attached when the first component is positive;
    # otherwise p.tensorop keeps whatever value it already had.
    top = desp.tensorop
    if top[0] > 0:
        p.tensorop = TensorOp((top[0], top[1], top[2]))

    p.num_stage = desp.num_stage
    p.algo = GemmAlgo(desp.algo)
    p.splitk_serial = desp.split_k_serial
    p.splitk_parallel = desp.split_k_parallel
    # Convert between the two enum types through the shared underlying value.
    p.shuffle_stride = ShuffleStrideType(desp.shuffle_type.value)
    p.access_per_vector = desp.access_per_vector
    p.is_nvrtc = desp.is_nvrtc
def get_gemm_param_from_desp(desp: GemmAlgoDesp):
    """Build ``GemmAlgoParams`` from the descriptor ``desp``.

    The object is first constructed with placeholder arguments and then
    passed to ``_assign_gemm_params``, which copies the descriptor's fields
    onto it.

    Args:
        desp: GEMM algorithm descriptor.

    Returns:
        The populated ``GemmAlgoParams``.
    """
    placeholder_shape = (0, 0, 0)
    gemm_params = GemmAlgoParams(
        placeholder_shape,
        (0, 0, 0),
        0,
        "s8,s8,s8,s8,s8",
        False,
        False,
        False,
        GemmAlgo.Simt,
    )
    _assign_gemm_params(desp, gemm_params)
    return gemm_params
def get_conv_param_from_desp(desp: ConvAlgoDesp):
    """Build ``ConvAlgoParams`` from the descriptor ``desp``.

    The object is first constructed with placeholder arguments (only
    ``ndim`` comes from ``desp``). ``_assign_gemm_params`` then copies the
    GEMM-level fields, and the convolution-only fields are copied here.

    Args:
        desp: convolution algorithm descriptor.

    Returns:
        The populated ``ConvAlgoParams``.
    """
    placeholder_shape = (0, 0, 0)
    conv_params = ConvAlgoParams(
        desp.ndim,
        ConvOpType.kForward,
        ConvIterAlgo.Optimized,
        placeholder_shape,
        (0, 0, 0),
        0,
        "s8,s8,s8,s8,s8",
        NHWC,
        NHWC,
        NHWC,
        GemmAlgo.Simt,
    )
    _assign_gemm_params(desp, conv_params)

    # Convolution-specific attributes.
    conv_params.ndim = desp.ndim
    conv_params.op_type = ConvOpType(desp.op_type.value)
    conv_params.iter_algo = ConvIterAlgo(desp.iter_algo.value)

    # Rebuild each layout from its layout type value and interleave.
    in_type = ConvLayoutType(desp.layout_i.value)
    conv_params.layout_desp_input = ConvLayout(in_type, desp.interleave_i)
    w_type = ConvLayoutType(desp.layout_w.value)
    conv_params.layout_desp_weight = ConvLayout(w_type, desp.interleave_w)
    out_type = ConvLayoutType(desp.layout_o.value)
    conv_params.layout_desp_output = ConvLayout(out_type, desp.interleave_o)

    conv_params.mask_sparse = desp.mask_sparse
    conv_params.increment_k_first = desp.increment_k_first
    return conv_params
spconv/build.py
View file @
7af751dc
...
@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
...
@@ -29,21 +29,20 @@ if project_is_installed(PACKAGE_NAME) and project_is_editable(
from
cumm.common
import
CompileInfo
from
cumm.common
import
CompileInfo
from
spconv.csrc.sparse.all
import
SpconvOps
from
spconv.csrc.sparse.all
import
SpconvOps
from
spconv.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.utils
import
BoxOps
from
spconv.csrc.hash.core
import
HashTable
from
spconv.csrc.hash.core
import
HashTable
all_shuffle
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
cu
=
GemmMainUnitTest
(
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
all_shuffle
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_shuffle
))
SHUFFLE_TURING_PARAMS
)
cu
=
GemmMainUnitTest
(
all_shuffle
)
cu
.
namespace
=
"cumm.gemm.main"
cu
.
namespace
=
"cumm.gemm.main"
convcu
=
ConvMainUnitTest
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
all_imp
=
(
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_TURING_PARAMS
)
IMPLGEMM_TURING_PARAMS
)
all_imp
=
list
(
filter
(
lambda
x
:
not
x
.
is_nvrtc
,
all_imp
))
convcu
=
ConvMainUnitTest
(
all_imp
)
convcu
.
namespace
=
"cumm.conv.main"
convcu
.
namespace
=
"cumm.conv.main"
objects_folder
=
None
pccm
.
builder
.
build_pybind
([
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
(),
ExternalAllocator
()],
if
InWindows
:
# windows have command line limit, so we use objects_folder to reduce command size.
objects_folder
=
"objects"
pccm
.
builder
.
build_pybind
([
cu
,
convcu
,
SpconvOps
(),
BoxOps
(),
HashTable
(),
CompileInfo
()],
PACKAGE_ROOT
/
"core_cc"
,
PACKAGE_ROOT
/
"core_cc"
,
namespace_root
=
PACKAGE_ROOT
,
namespace_root
=
PACKAGE_ROOT
,
objects_folder
=
objects_folder
,
load_library
=
False
)
load_library
=
False
)
spconv/constants.py
View file @
7af751dc
...
@@ -16,6 +16,8 @@ import os
...
@@ -16,6 +16,8 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
from
typing
import
List
from
pccm.utils
import
project_is_editable
,
project_is_installed
from
pccm.utils
import
project_is_editable
,
project_is_installed
from
cumm.gemm.constants
import
NVRTCMode
import
enum
PACKAGE_NAME
=
"spconv"
PACKAGE_NAME
=
"spconv"
PACKAGE_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
PACKAGE_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
...
@@ -43,3 +45,21 @@ else:
...
@@ -43,3 +45,21 @@ else:
# for f16 backward weight, larger splitk, larger compute error.
# for f16 backward weight, larger splitk, larger compute error.
# so we use this env to control maximum splitk.
# so we use this env to control maximum splitk.
SPCONV_BWD_SPLITK
=
list
(
map
(
int
,
os
.
getenv
(
"SPCONV_BWD_SPLITK"
,
"1,2,4,8,16,32,64"
).
split
(
","
)))
SPCONV_BWD_SPLITK
=
list
(
map
(
int
,
os
.
getenv
(
"SPCONV_BWD_SPLITK"
,
"1,2,4,8,16,32,64"
).
split
(
","
)))
SPCONV_NVRTC_MODE
=
NVRTCMode
.
ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS
=
False
class SpconvAllocatorKeys:
    """String keys grouped under one namespace.

    Every attribute's value is the attribute's own name as a string.
    """
    # NOTE(review): presumably these name the buffers requested from an
    # external allocator -- confirm against the callers.
    Pair = "Pair"
    IndiceNumPerLoc = "IndiceNumPerLoc"
    PairMask = "PairMask"
    MaskArgSort = "MaskArgSort"
    OutIndices = "OutIndices"

    PairFwd = "PairFwd"
    # PairMaskFwd = "PairMaskFwd"
    PairMaskBwd = "PairMaskBwd"
    # MaskArgSortFwd = "MaskArgSortFwd"
    MaskArgSortBwd = "MaskArgSortBwd"

    OutFeatures = "OutFeatures"
spconv/core.py
View file @
7af751dc
...
@@ -15,17 +15,17 @@ from enum import Enum
...
@@ -15,17 +15,17 @@ from enum import Enum
from
cumm.gemm.main
import
gen_shuffle_params_v2
as
gen_shuffle_params
,
GemmAlgoParams
from
cumm.gemm.main
import
gen_shuffle_params_v2
as
gen_shuffle_params
,
GemmAlgoParams
from
cumm.gemm
import
kernel
from
cumm.gemm
import
kernel
from
typing
import
List
from
typing
import
List
from
cumm.gemm.algospec.core
import
TensorOp
Params
from
cumm.gemm.algospec.core
import
TensorOp
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.main
import
gen_gemm_params
as
gen_conv_params
,
ConvFwdAndBwdInput
,
ConvBwdWeight
,
ConvIterAlgo
,
GemmAlgo
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvEnum
,
ConvIterAlgo
,
ConvLayout
,
from
cumm.conv.bases
import
(
NCHW
,
NHWC
,
ConvIterAlgo
,
ConvLayout
,
ConvLayoutType
,
ConvMode
,
ConvOpType
)
ConvLayoutType
,
ConvMode
,
ConvOpType
)
from
spconv.constants
import
NDIM_DONT_CARE
from
spconv.constants
import
NDIM_DONT_CARE
class
ConvAlgo
(
Enum
):
class
ConvAlgo
(
Enum
):
Native
=
"Native"
Native
=
0
MaskImplicitGemm
=
"MaskImplicitGemm"
MaskImplicitGemm
=
1
MaskSplitImplicitGemm
=
"MaskSplitImplicitGemm"
MaskSplitImplicitGemm
=
2
class
AlgoHint
(
Enum
):
class
AlgoHint
(
Enum
):
...
@@ -40,17 +40,17 @@ class AlgoHint(Enum):
...
@@ -40,17 +40,17 @@ class AlgoHint(Enum):
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
SHUFFLE_SIMT_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
*
gen_shuffle_params
((
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
(
64
,
32
,
32
),
[
"s8,s8,s8,s32,s32"
,
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
),
2
,
kernel
.
GemmAlgo
.
SimtDP4A
,
None
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
*
gen_shuffle_params
((
64
,
256
,
8
),
(
32
,
64
,
8
),
[
"f32,f32,f32,f32,f32"
],
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
"f32,f32,f32,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Simt
,
None
),
# *gen_shuffle_params(
# *gen_shuffle_params(
...
@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
...
@@ -104,88 +104,88 @@ SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
32
),
(
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# (64, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Volta, TensorOp
Params
((8, 8, 4))),
# kernel.GemmAlgo.Volta, TensorOp((8, 8, 4))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
64
,
32
),
(
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
32
),
(
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
))),
kernel
.
GemmAlgo
.
Volta
,
TensorOp
((
8
,
8
,
4
))),
]
]
# SHUFFLE_VOLTA_PARAMS = []
# SHUFFLE_VOLTA_PARAMS = []
SHUFFLE_TURING_PARAMS
:
List
[
GemmAlgoParams
]
=
[
SHUFFLE_TURING_PARAMS
:
List
[
GemmAlgoParams
]
=
[
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
32
),
(
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# (64, 32, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2,
# kernel.GemmAlgo.Turing, TensorOp
Params
((16, 8, 8))),
# kernel.GemmAlgo.Turing, TensorOp((16, 8, 8))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
64
,
64
),
(
64
,
64
,
64
),
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
64
),
(
64
,
128
,
64
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
64
,
32
),
(
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
64
,
32
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
64
,
128
,
32
),
(
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
(
32
,
64
,
32
),
[
"f16,f16,f16,f16,f16"
],
"f16,f16,f16,f32,f32"
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
))),
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
16
,
8
,
8
))),
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
64
,
32
),
(
32
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
128
,
32
),
(
128
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
# *gen_shuffle_params(
# *gen_shuffle_params(
# (128, 128, 32),
# (128, 128, 32),
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2,
# kernel.GemmAlgo.Turing, TensorOp
Params
((8, 8, 16))),
# kernel.GemmAlgo.Turing, TensorOp((8, 8, 16))),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
128
,
256
,
32
),
(
128
,
256
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
(
*
gen_shuffle_params
(
(
256
,
128
,
32
),
(
256
,
128
,
32
),
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
(
64
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
128
,
64
,
32
),
(
64
,
32
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
*
gen_shuffle_params
((
64
,
128
,
32
),
(
32
,
64
,
32
),
[
"s8,s8,s32,s32,s32"
],
""
,
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
Params
((
8
,
8
,
16
))),
2
,
kernel
.
GemmAlgo
.
Turing
,
TensorOp
((
8
,
8
,
16
))
,
is_nvrtc
=
True
),
]
]
# SHUFFLE_TURING_PARAMS = []
# SHUFFLE_TURING_PARAMS = []
...
@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
...
@@ -399,6 +399,34 @@ IMPLGEMM_SIMT_PARAMS = [
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
]
]
IMPLGEMM_SIMT_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
*
gen_conv_params
(
ConvBwdWeight
,
(
64
,
32
,
16
),
(
32
,
32
,
8
),
NDIM_DONT_CARE
,
ConvIterAlgo
.
Optimized
,
2
,
[
"f32,f32,f32,f32,f32"
,
"f16,f16,f16,f32,f32"
],
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Simt
,
None
,
mask_sparse
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
]
IMPLGEMM_VOLTA_PARAMS
=
[
IMPLGEMM_VOLTA_PARAMS
=
[
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
*
gen_conv_params
(
ConvFwdAndBwdInput
,
(
64
,
64
,
32
),
(
32
,
32
,
32
),
NDIM_DONT_CARE
,
NDIM_DONT_CARE
,
...
@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -408,7 +436,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -420,7 +448,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -432,7 +460,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -444,7 +472,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -456,7 +484,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -468,7 +496,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
...
@@ -480,7 +508,7 @@ IMPLGEMM_VOLTA_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Volta
,
GemmAlgo
.
Volta
,
TensorOp
Params
((
8
,
8
,
4
)),
TensorOp
((
8
,
8
,
4
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -495,7 +523,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -507,7 +535,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -519,7 +547,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -531,7 +559,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -543,7 +571,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -555,7 +583,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -567,7 +595,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -579,7 +607,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -591,7 +619,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -603,7 +631,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -615,7 +643,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -628,7 +656,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
0
),
access_per_vector
=
0
),
...
@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -641,7 +669,7 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
...
@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
...
@@ -654,12 +682,16 @@ IMPLGEMM_TURING_PARAMS = [
NHWC
,
NHWC
,
NHWC
,
NHWC
,
GemmAlgo
.
Turing
,
GemmAlgo
.
Turing
,
TensorOp
Params
((
16
,
8
,
8
)),
TensorOp
((
16
,
8
,
8
)),
mask_sparse
=
True
,
mask_sparse
=
True
,
increment_k_first
=
True
,
increment_k_first
=
True
,
access_per_vector
=
1
),
access_per_vector
=
1
),
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32",
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp
Params
((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOp((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1),
# gen_conv_params(ConvFwdAndBwdInput, )
# gen_conv_params(ConvFwdAndBwdInput, )
]
]
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
ALL_IMPGEMM_PARAMS
=
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
spconv/core_cc/csrc/sparse/all/__init__.pyi
View file @
7af751dc
...
@@ -34,13 +34,15 @@ class SpconvOps:
...
@@ -34,13 +34,15 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_stage2(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_stage2(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs: Tensor, indice_pairs_uniq: Tensor,
indice_pairs_uniq_before_sort: Tensor,
out_inds: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs:
indice_pairs:
indice_pairs_uniq:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
out_inds:
num_out_act:
num_out_act:
batch_size:
batch_size:
...
@@ -74,14 +76,16 @@ class SpconvOps:
...
@@ -74,14 +76,16 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor, out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
def generate_conv_inds_mask_stage2(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs_fwd: Tensor, indice_pairs_bwd: Tensor, indice_pairs_uniq: Tensor,
indice_pairs_uniq_before_sort: Tensor,
out_inds: Tensor, mask_fwd: Tensor, mask_bwd: Tensor, num_out_act: int, batch_size: int, output_dims: List[int], input_dims: List[int], ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs_fwd:
indice_pairs_fwd:
indice_pairs_bwd:
indice_pairs_bwd:
indice_pairs_uniq:
indice_pairs_uniq:
indice_pairs_uniq_before_sort:
out_inds:
out_inds:
mask_fwd:
mask_fwd:
mask_bwd:
mask_bwd:
...
@@ -98,11 +102,12 @@ class SpconvOps:
...
@@ -98,11 +102,12 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def generate_subm_conv_inds(indices: Tensor, hashdata: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
def generate_subm_conv_inds(indices: Tensor, hashdata
_k: Tensor, hashdata_v
: Tensor, indice_pairs: Tensor, out_inds: Tensor, indice_num_per_loc: Tensor, batch_size: int, input_dims: List[int], ksize: List[int], dilation: List[int], indice_pair_mask: Tensor = Tensor(), backward: bool = False, stream_int: int = 0) -> int:
"""
"""
Args:
Args:
indices:
indices:
hashdata:
hashdata_k:
hashdata_v:
indice_pairs:
indice_pairs:
out_inds:
out_inds:
indice_num_per_loc:
indice_num_per_loc:
...
@@ -276,6 +281,18 @@ class SpconvOps:
...
@@ -276,6 +281,18 @@ class SpconvOps:
"""
"""
...
...
@staticmethod
@staticmethod
def sort_1d_by_key_split_allocator_v2(data: Tensor, allocator, mask: Tensor, indices: Tensor = Tensor(), stream: int = 0, mask_output: bool = False) -> Tensor:
"""
Args:
data:
allocator:
mask:
indices:
stream:
mask_output:
"""
...
@staticmethod
def count_bits(a: Tensor) -> Tensor:
def count_bits(a: Tensor) -> Tensor:
"""
"""
Args:
Args:
...
@@ -328,3 +345,51 @@ class SpconvOps:
...
@@ -328,3 +345,51 @@ class SpconvOps:
stream_int:
stream_int:
"""
"""
...
...
@staticmethod
def get_int32_max() -> int: ...
@staticmethod
def get_indice_pairs_implicit_gemm(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, is_train: bool, stream_int: int = 0) -> Tensor:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
is_train:
stream_int:
"""
...
@staticmethod
def get_indice_pairs(allocator, indices: Tensor, batch_size: int, input_dims: List[int], algo: int, ksize: List[int], stride: List[int], padding: List[int], dilation: List[int], out_padding: List[int], subm: bool, transposed: bool, stream_int: int = 0) -> None:
"""
Args:
allocator:
indices:
batch_size:
input_dims:
algo:
ksize:
stride:
padding:
dilation:
out_padding:
subm:
transposed:
stream_int:
"""
...
@staticmethod
def test_allocator(allocator) -> None:
"""
Args:
allocator:
"""
...
spconv/core_cc/csrc/sparse/alloc.pyi
0 → 100644
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def empty(self, name: str, shape: List[int], dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
dtype:
device:
"""
...
def full_int(self, name: str, shape: List[int], value: int, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def full_float(self, name: str, shape: List[int], value: float, dtype: int, device: int) -> Tensor:
"""
Args:
name:
shape:
value:
dtype:
device:
"""
...
def free(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
def free_noexcept(self, ten: Tensor) -> None:
"""
Args:
ten:
"""
...
spconv/core_cc/cumm/__init__.pyi
View file @
7af751dc
# Copyright 2022 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spconv/core_cc/cumm/conv/main.pyi
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from ...cumm.gemm.main import GemmAlgoDesp
from cumm.tensorview.gemm import ConvParams
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ConvAlgoDesp(GemmAlgoDesp):
ndim: int
op_type: int
iter_algo: int
layout_i: int
layout_w: int
layout_o: int
interleave_i: int
interleave_w: int
interleave_o: int
mask_sparse: bool
increment_k_first: bool
def __init__(self, ndim: int, op_type: int) -> None:
"""
Args:
ndim:
op_type:
"""
...
def __repr__(self) -> str: ...
@staticmethod
def conv_iwo_012_to_abc(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@staticmethod
def gemm_abc_012_to_iwo(op_type: int) -> List[int]:
"""
Args:
op_type:
"""
...
@property
def dtype_input(self) -> int: ...
@property
def dtype_weight(self) -> int: ...
@property
def dtype_output(self) -> int: ...
def supported(self, m: int, n: int, k: int, C: int, K: int, mask_width: int) -> bool:
"""
Args:
m:
n:
k:
C:
K:
mask_width:
"""
...
def query_conv_workspace_size(self, m: int, n: int, k: int, split_k_slices: int, kv: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
kv:
"""
...
def supported_ldx_conv(self, ldi: int, ldw: int, ldo: int) -> bool:
"""
Args:
ldi:
ldw:
ldo:
"""
...
class ConvParams:
conv_algo_desp: Any
input: Tensor
weight: Tensor
output: Tensor
split_k_slices: int
padding: List[int]
stride: List[int]
dilation: List[int]
alpha: float
beta: float
mask_width: int
mask_filter: int
reverse_mask: bool
verbose: bool
timer: CUDAKernelTimer
workspace: Tensor = Tensor()
mask: Tensor = Tensor()
mask_argsort: Tensor = Tensor()
indices: Tensor = Tensor()
mask_output: Tensor = Tensor()
stream: int
def __init__(self, ndim: int, op_type: int, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
"""
Args:
ndim:
op_type:
timer:
"""
...
class ConvMainUnitTest:
class ConvMainUnitTest:
@staticmethod
@staticmethod
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
def extract_mnk(op_type: int, N: int, C: int, K: int, kernel_volume: int, in_prod: int, out_prod: int, mask_sparse: bool) -> List[int]:
...
...
spconv/core_cc/cumm/gemm/main.pyi
View file @
7af751dc
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import GemmParams
from cumm.tensorview import CUDAKernelTimer
class GemmAlgoDesp:
dtype_a: int
dtype_b: int
dtype_c: int
tile_shape: Tuple[int, int, int]
warp_tile_shape: Tuple[int, int, int]
num_stage: int
dacc: int
dcomp: int
algo: str
tensorop: List[int]
split_k_serial_: int
split_k_parallel_: int
shuffle_type: str
element_per_access_a: int
element_per_access_b: int
element_per_access_c: int
access_per_vector: int
def __init__(self) -> None: ...
def __repr__(self) -> str: ...
@property
def split_k_serial(self) -> bool: ...
@split_k_serial.setter
def split_k_serial(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def split_k_parallel(self) -> bool: ...
@split_k_parallel.setter
def split_k_parallel(self, val: bool) -> None:
"""
Args:
val:
"""
...
def check_valid(self) -> None: ...
@property
def trans_a(self) -> bool: ...
@trans_a.setter
def trans_a(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_b(self) -> bool: ...
@trans_b.setter
def trans_b(self, val: bool) -> None:
"""
Args:
val:
"""
...
@property
def trans_c(self) -> bool: ...
@trans_c.setter
def trans_c(self, val: bool) -> None:
"""
Args:
val:
"""
...
def query_workspace_size(self, m: int, n: int, k: int, split_k_slices: int) -> int:
"""
Args:
m:
n:
k:
split_k_slices:
"""
...
def supported(self, m: int, n: int, k: int) -> bool:
"""
Args:
m:
n:
k:
"""
...
def supported_ldx(self, lda: int, ldb: int, ldc: int) -> bool:
"""
Args:
lda:
ldb:
ldc:
"""
...
class GemmParams:
    """Type stub for ``GemmParams``.

    Judging by the name this bundles the arguments of one GEMM call; the stub
    only declares attribute types and signatures.
    """
    algo_desp: GemmAlgoDesp
    split_k_slices: int
    # The next four attributes are declared with a ``Tensor()`` default.
    workspace: Tensor = Tensor()
    a_inds: Tensor = Tensor()
    b_inds: Tensor = Tensor()
    c_inds: Tensor = Tensor()
    alpha: float
    beta: float
    # NOTE(review): an int; presumably a raw stream handle -- confirm.
    stream: int
    timer: CUDAKernelTimer
    def __init__(self, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None:
        """Create the parameter object.

        Args:
            timer: kernel timer; defaults to ``CUDAKernelTimer(False)``.
        """
        ...
    def check_valid(self) -> None: ...
    @property
    def a(self) -> Tensor: ...
    @a.setter
    def a(self, val: Tensor) -> None:
        """Set ``a``.

        Args:
            val: tensor assigned to the property.
        """
        ...
    @property
    def b(self) -> Tensor: ...
    @b.setter
    def b(self, val: Tensor) -> None:
        """Set ``b``.

        Args:
            val: tensor assigned to the property.
        """
        ...
    @property
    def c(self) -> Tensor: ...
    @c.setter
    def c(self, val: Tensor) -> None:
        """Set ``c``.

        Args:
            val: tensor assigned to the property.
        """
        ...
class GemmMainUnitTest:
class GemmMainUnitTest:
@staticmethod
@staticmethod
def get_all_algo_desp() -> List[
GemmAlgoDesp
]: ...
def get_all_algo_desp() -> List[
Any
]: ...
@staticmethod
@staticmethod
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "
NS
", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
def extract_mnk(a_shape: List[int], b_shape: List[int], trans_a: bool, trans_b: bool, trans_c: bool, shuffle_type: str = "
0
", a_inds_shape: List[int] = [], b_inds_shape: List[int] = [], c_inds_shape: List[int] = []) -> Tuple[int, int, int]:
"""
"""
Args:
Args:
a_shape:
a_shape:
...
...
spconv/csrc/hash/core.py
View file @
7af751dc
...
@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -104,6 +104,8 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
self
.
add_member
(
"map_8_8"
,
"tsl::robin_map<uint64_t, uint64_t>"
)
self
.
add_member
(
"map_8_8"
,
"tsl::robin_map<uint64_t, uint64_t>"
)
self
.
add_pybind_member
(
"insert_count_"
,
"int64_t"
,
prop_name
=
"insert_count"
,
readwrite
=
False
)
self
.
add_pybind_member
(
"insert_count_"
,
"int64_t"
,
prop_name
=
"insert_count"
,
readwrite
=
False
)
self
.
valid_hash_key_types
=
[
dtypes
.
int32
,
dtypes
.
int64
,
dtypes
.
uint32
,
dtypes
.
uint64
]
@
pccm
.
pybind
.
mark
@
pccm
.
pybind
.
mark
@
pccm
.
constructor
@
pccm
.
constructor
def
ctor
(
self
):
def
ctor
(
self
):
...
@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -163,11 +165,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
"""
)
"""
)
for
v_items
in
_dispatch_ints
(
code
,
[
4
,
8
],
"values_data.itemsize()"
):
for
v_items
in
_dispatch_ints
(
code
,
[
4
,
8
],
"values_data.itemsize()"
):
...
@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -176,10 +176,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream);
tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::clear_
table
_split<table_t>, table);
launcher(tv::hash::clear_
map_kernel
_split<table_t>, table);
"""
)
"""
)
return
code
return
code
...
@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -201,9 +201,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
int64_t value_after_insert = keys.dim(0) + insert_count_;
int64_t value_after_insert = keys.dim(0) + insert_count_;
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
TV_ASSERT_RT_ERR(value_after_insert < keys_data.dim(0), "inserted count exceed maximum hash size");
insert_count_ += keys.dim(0);
insert_count_ += keys.dim(0);
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
}}
auto N = keys.dim(0);
auto N = keys.dim(0);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
if (!values.empty()){{
if (!values.empty()){{
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
TV_ASSERT_RT_ERR(keys.dim(0) == values.dim(0), "number of key and value must same");
...
@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -231,10 +231,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...
@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -248,7 +247,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
launcher(tv::hash::insert_split<table_t>, table, key_ptr, value_ptr, size_t(N));
...
@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -279,6 +278,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
TV_ASSERT_RT_ERR(N == values.dim(0) && is_empty.dim(0) == N, "number of key and value must same");
auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
auto is_empty_ptr = is_empty.data_ptr<uint8_t>();
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
"""
)
"""
)
with
code
.
if_
(
"is_cpu"
):
with
code
.
if_
(
"is_cpu"
):
map_name
=
"cpu_map"
map_name
=
"cpu_map"
...
@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -304,10 +306,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...
@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -319,7 +320,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
launcher(tv::hash::query_split<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...
@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -361,11 +362,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
TV_ASSERT_RT_ERR(count.device() == 0, "count must be cuda");
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<K>();
auto count_ptr = count.data_ptr<Kunsigned>();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
"""
)
"""
)
...
@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -376,10 +378,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
V* value_data_ptr = reinterpret_cast<V*>(values_data.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(table.size(), custream);
tv::cuda::Launch launcher(table.size(), custream);
launcher(tv::hash::assign_arange_split<table_t, K>, table, count_ptr);
launcher(tv::hash::assign_arange_split<table_t, K
unsigned
>, table, count_ptr);
"""
)
"""
)
else
:
else
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
...
@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -426,7 +428,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(keys.itemsize() == key_itemsize_, "keys itemsize not equal to", key_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(values.itemsize() == value_itemsize_, "values itemsize not equal to", value_itemsize_);
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
TV_ASSERT_RT_ERR(N == values.dim(0), "number of key and value must same");
if (!is_cpu){{
TV_ASSERT_RT_ERR(keys.dtype() == keys_data.dtype(), "keys dtype not equal to", keys_data.dtype());
}}
"""
)
"""
)
with
code
.
if_
(
"is_cpu"
):
with
code
.
if_
(
"is_cpu"
):
map_name
=
"cpu_map"
map_name
=
"cpu_map"
...
@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -450,12 +454,12 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
auto count_ptr = count.data_ptr<K>();
using Kunsigned = tv::hash::itemsize_to_unsigned_t<sizeof(K)>;
auto count_ptr = count.data_ptr<Kunsigned>();
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
K* key_ptr = reinterpret_cast<K*>(keys.raw_data());
...
@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -467,10 +471,10 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
V* value_ptr = reinterpret_cast<V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
launcher(tv::hash::iterate_table_split<table_t, K>, table, key_ptr, value_ptr, size_t(N), count_ptr);
launcher(tv::hash::iterate_table_split<table_t, K
unsigned
>, table, key_ptr, value_ptr, size_t(N), count_ptr);
"""
)
"""
)
else
:
else
:
code
.
raw
(
f
"""
code
.
raw
(
f
"""
...
@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -523,10 +527,9 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
auto custream = reinterpret_cast<cudaStream_t>(stream);
auto custream = reinterpret_cast<cudaStream_t>(stream);
"""
)
"""
)
for
k_items
in
_dispatch
_ints
(
code
,
[
4
,
8
]
,
"keys_data.
itemsiz
e()"
):
for
k_items
in
_dispatch
(
code
,
self
.
valid_hash_key_types
,
"keys_data.
dtyp
e()"
):
code
.
raw
(
f
"""
code
.
raw
(
f
"""
using K = tv::hash::itemsize_to_unsigned_t<
{
k_items
}
>;
using K =
{
k_items
}
;
constexpr K kEmptyKey = std::numeric_limits<K>::max();
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
K* key_data_ptr = reinterpret_cast<K*>(keys_data.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
const K* key_ptr = reinterpret_cast<const K*>(keys.raw_data());
...
@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
...
@@ -538,7 +541,7 @@ class HashTable(pccm.Class, pccm.pybind.PybindClassMixin):
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
const V* value_ptr = reinterpret_cast<const V*>(values.raw_data());
using table_t =
using table_t =
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
tv::hash::LinearHashTableSplit<K, V, tv::hash::Murmur3Hash<K>,
kEmptyKey
, false>;
tv::hash::default_empty_key_v<K>
, false>;
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
table_t table(key_data_ptr, value_data_ptr, keys_data.dim(0));
tv::cuda::Launch launcher(N, custream);
tv::cuda::Launch launcher(N, custream);
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
launcher(insert_exist_keys_kernel<table_t>, table, key_ptr, value_ptr, is_empty_ptr, size_t(N));
...
...
spconv/csrc/sparse/all.py
View file @
7af751dc
This diff is collapsed.
Click to expand it.
spconv/csrc/sparse/alloc.py
0 → 100644
View file @
7af751dc
import
pccm
from
cumm.common
import
TensorView
,
TensorViewCPU
,
TensorViewKernel
,
ThrustLib
# pccm description of a C++ class holding a ``tv::Tensor`` together with a
# ``std::function`` callback; the generated destructor hands the tensor to the
# callback when the tensor is non-empty and the callback is set.
class ExternalAllocatorGuard(pccm.Class):
    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView)
        # C++ data members, declared in this order.
        for member_name, cpp_type in (
            ("tensor", "tv::Tensor"),
            ("free_func", "std::function<void(tv::Tensor)>"),
        ):
            self.add_member(member_name, cpp_type)

    # Constructor taking the tensor and the callback and storing both.
    @pccm.constructor
    def ctor(self):
        body = pccm.code()
        body.arg("ten", "tv::Tensor")
        body.arg("free_func", "std::function<void(tv::Tensor)>")
        body.ctor_init("tensor", "ten")
        body.ctor_init("free_func", "free_func")
        return body

    # Argument-less constructor with an empty body.
    @pccm.constructor
    def dctor(self):
        return pccm.code()

    # Destructor: invoke the callback only when there is something to release.
    @pccm.destructor
    def dtor(self):
        body = pccm.code()
        body.raw(
            '\n'
            'if (!tensor.empty() && free_func){\n'
            'free_func(tensor);\n'
            '}\n'
        )
        return body
# pccm description of an abstract C++ allocator interface.
#
# ``zeros`` / ``empty`` / ``full_int`` / ``full_float`` / ``free`` /
# ``free_noexcept`` are declared pure virtual and marked virtual for pybind.
# The ``*_guard`` helpers call the matching factory with the name "" and wrap
# the result in an ``ExternalAllocatorGuard`` whose callback calls
# ``this->free``.
class ExternalAllocator(pccm.Class):
    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView, ExternalAllocatorGuard)
        self.use_shared = True
        # Smart-pointer flavour used for guards: "shared" or "unique".
        self.ptr_type = "shared" if self.use_shared else "unique"
        self.add_typedef("guard_t",
                         f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")

    # Pure virtual: (name, shape, dtype, device) -> tv::Tensor.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def zeros(self):
        code = pccm.code()
        code.arg("name", "std::string")
        code.arg("shape", "std::vector<int64_t>")
        code.arg("dtype", "int")
        code.arg("device", "int")
        return code.ret("tv::Tensor")

    # Pure virtual: (name, shape, dtype, device) -> tv::Tensor.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def empty(self):
        code = pccm.code()
        code.arg("name", "std::string")
        code.arg("shape", "std::vector<int64_t>")
        code.arg("dtype", "int")
        code.arg("device", "int")
        return code.ret("tv::Tensor")

    # Pure virtual: (name, shape, int value, dtype, device) -> tv::Tensor.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def full_int(self):
        code = pccm.code()
        code.arg("name", "std::string")
        code.arg("shape", "std::vector<int64_t>")
        code.arg("value", "int")
        code.arg("dtype", "int")
        code.arg("device", "int")
        return code.ret("tv::Tensor")

    # Pure virtual: (name, shape, float value, dtype, device) -> tv::Tensor.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def full_float(self):
        code = pccm.code()
        code.arg("name", "std::string")
        code.arg("shape", "std::vector<int64_t>")
        code.arg("value", "float")
        code.arg("dtype", "int")
        code.arg("device", "int")
        return code.ret("tv::Tensor")

    # Pure virtual: takes a tv::Tensor, no return type declared.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def free(self):
        code = pccm.code()
        code.arg("ten", "tv::Tensor")
        return code

    # Pure virtual: same signature as ``free``.
    @pccm.pybind.mark(virtual=True)
    @pccm.member_function(virtual=True, pure_virtual=True)
    def free_noexcept(self):
        code = pccm.code()
        code.arg("ten", "tv::Tensor")
        return code

    # zeros("", ...) wrapped in a guard.
    @pccm.member_function
    def zeros_guard(self):
        code = pccm.code()
        code.arg("shape", "std::vector<int64_t>")
        code.arg("dtype", "int")
        code.arg("device", "int")
        code.raw(
            '\n'
            '// "" means temp memory\n'
            'auto ten = zeros("", shape, dtype, device);\n'
            f'return std::make_{self.ptr_type}<ExternalAllocatorGuard>'
            '(ten, [this](tv::Tensor ten){\n'
            'this->free(ten);\n'
            '});\n'
        )
        return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")

    # empty("", ...) wrapped in a guard.
    @pccm.member_function
    def empty_guard(self):
        code = pccm.code()
        code.arg("shape", "std::vector<int64_t>")
        code.arg("dtype", "int")
        code.arg("device", "int")
        code.raw(
            '\n'
            'auto ten = empty("", shape, dtype, device);\n'
            f'return std::make_{self.ptr_type}<ExternalAllocatorGuard>'
            '(ten, [this](tv::Tensor ten){\n'
            'this->free(ten);\n'
            '});\n'
        )
        return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")

    # full_int("", ...) wrapped in a guard.
    @pccm.member_function
    def full_int_guard(self):
        code = pccm.code()
        code.arg("shape", "std::vector<int64_t>")
        code.arg("value", "int")
        code.arg("dtype", "int")
        code.arg("device", "int")
        code.raw(
            '\n'
            'auto ten = full_int("", shape, value, dtype, device);\n'
            f'return std::make_{self.ptr_type}<ExternalAllocatorGuard>'
            '(ten, [this](tv::Tensor ten){\n'
            'this->free(ten);\n'
            '});\n'
        )
        return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")

    # full_float("", ...) wrapped in a guard.
    @pccm.member_function
    def full_float_guard(self):
        code = pccm.code()
        code.arg("shape", "std::vector<int64_t>")
        # Must be "float" to match ``full_float``; declaring it "int" would
        # truncate fractional fill values before they are forwarded.
        code.arg("value", "float")
        code.arg("dtype", "int")
        code.arg("device", "int")
        code.raw(
            '\n'
            'auto ten = full_float("", shape, value, dtype, device);\n'
            f'return std::make_{self.ptr_type}<ExternalAllocatorGuard>'
            '(ten, [this](tv::Tensor t){\n'
            'this->free(t);\n'
            '});\n'
        )
        return code.ret(f"std::{self.ptr_type}_ptr<ExternalAllocatorGuard>")
# pccm description of a C++ class with ``allocate`` / ``deallocate`` members
# and a ``value_type`` typedef of ``char``; both members forward to a
# referenced ``ExternalAllocator``.
# NOTE(review): the name suggests use as a Thrust allocator -- confirm.
class ThrustAllocator(pccm.Class):
    def __init__(self):
        super().__init__()
        self.add_dependency(TensorView, ExternalAllocator)
        self.add_include("functional", "memory")
        # Stored as a C++ reference member.
        self.add_member("allocator_", "ExternalAllocator&")
        self.add_typedef("value_type", "char")

    # Constructor binding the reference member to the given allocator.
    @pccm.constructor
    def ctor(self):
        ctor_code = pccm.code()
        ctor_code.arg("allocator", "ExternalAllocator&")
        ctor_code.ctor_init("allocator_", "allocator")
        return ctor_code

    # allocate(num_bytes) -> char*: requests a uint8 tensor of ``num_bytes``
    # elements on device 0 from the external allocator and returns its data
    # pointer.
    @pccm.member_function
    def allocate(self):
        alloc_code = pccm.FunctionCode()
        alloc_code.arg("num_bytes", "std::ptrdiff_t")
        alloc_code.ret("char*")
        cpp_lines = (
            '',
            'auto ten = allocator_.empty("", {num_bytes}, tv::uint8, 0);',
            'return reinterpret_cast<char*>(ten.raw_data());',
            '',
        )
        alloc_code.raw("\n".join(cpp_lines))
        return alloc_code

    # deallocate(ptr, num_bytes): wraps the pointer with ``tv::from_blob`` and
    # passes it to the allocator's ``free_noexcept``.
    @pccm.member_function
    def deallocate(self):
        free_code = pccm.FunctionCode()
        free_code.arg("ptr", "char *")
        free_code.arg("num_bytes", "size_t")
        cpp_lines = (
            '',
            'return allocator_.free_noexcept(tv::from_blob(ptr, {num_bytes}, tv::uint8, 0));',
            '',
        )
        free_code.raw("\n".join(cpp_lines))
        return free_code
spconv/csrc/sparse/convops.py
0 → 100644
View file @
7af751dc
import
pccm
from
cumm.gemm.main
import
GemmMainUnitTest
from
cumm.conv.main
import
ConvMainUnitTest
from
.alloc
import
ExternalAllocator
from
spconv.core
import
ConvAlgo
from
spconv.constants
import
SpconvAllocatorKeys
from
cumm.constants
import
CUMM_CPU_ONLY_BUILD
from
cumm.common
import
GemmBasicHost
,
TensorView
,
NlohmannJson
# pccm/pybind description of a C++ record with three members exposed to
# Python: ``algo_desp`` (tv::gemm::GemmAlgoDesp), ``arch`` (tuple<int, int>)
# and ``splitk`` (int).
class GemmTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
    def __init__(self):
        super().__init__()
        self.add_dependency(GemmBasicHost, TensorView)
        self.add_pybind_member("algo_desp", "tv::gemm::GemmAlgoDesp")
        self.add_pybind_member("arch", "std::tuple<int, int>")
        self.add_pybind_member("splitk", "int")

    # is_valid() -> bool: true when ``splitk`` and the first element of
    # ``arch`` are both positive (the default constructor sets them to -1).
    @pccm.pybind.mark
    @pccm.member_function
    def is_valid(self):
        code = pccm.code()
        # The statement needs its terminating ';' and the function must
        # declare the bool return type, otherwise the emitted C++ is invalid.
        code.raw("return splitk > 0 && std::get<0>(arch) > 0;")
        return code.ret("bool")

    # Default constructor: default descriptor, arch = (-1, -1), splitk = -1.
    @pccm.pybind.mark
    @pccm.constructor
    def defaultctor(self):
        code = pccm.code()
        code.ctor_init("algo_desp", "tv::gemm::GemmAlgoDesp()")
        code.ctor_init("arch", "std::make_tuple(-1, -1)")
        code.ctor_init("splitk", "-1")
        return code

    # Constructor storing the three given values in the members.
    @pccm.pybind.mark
    @pccm.constructor
    def ctor(self):
        code = pccm.code()
        code.arg("algo_desp", "tv::gemm::GemmAlgoDesp",
                 pyanno="cumm.tensorview.gemm.GemmAlgoDesp")
        code.arg("arch", "std::tuple<int, int>")
        code.arg("splitk", "int")
        code.ctor_init("algo_desp", "algo_desp")
        code.ctor_init("arch", "arch")
        code.ctor_init("splitk", "splitk")
        return code
# pccm/pybind description of a C++ record with three members exposed to
# Python: ``algo_desp`` (tv::gemm::ConvAlgoDesp), ``arch`` (tuple<int, int>)
# and ``splitk`` (int).
class ConvTuneResult(pccm.Class, pccm.pybind.PybindClassMixin):
    def __init__(self):
        super().__init__()
        self.add_dependency(GemmBasicHost, TensorView)
        self.add_pybind_member("algo_desp", "tv::gemm::ConvAlgoDesp")
        self.add_pybind_member("arch", "std::tuple<int, int>")
        self.add_pybind_member("splitk", "int")

    # Default constructor: default descriptor, arch = (-1, -1), splitk = -1.
    @pccm.pybind.mark
    @pccm.constructor
    def defaultctor(self):
        code = pccm.code()
        code.ctor_init("algo_desp", "tv::gemm::ConvAlgoDesp()")
        code.ctor_init("arch", "std::make_tuple(-1, -1)")
        code.ctor_init("splitk", "-1")
        return code

    # Constructor storing the three given values in the members.
    @pccm.pybind.mark
    @pccm.constructor
    def ctor(self):
        code = pccm.code()
        code.arg("algo_desp", "tv::gemm::ConvAlgoDesp",
                 pyanno="cumm.tensorview.gemm.ConvAlgoDesp")
        code.arg("arch", "std::tuple<int, int>")
        code.arg("splitk", "int")
        code.ctor_init("algo_desp", "algo_desp")
        code.ctor_init("arch", "arch")
        code.ctor_init("splitk", "splitk")
        return code

    # is_valid() -> bool: true when ``splitk`` and the first element of
    # ``arch`` are both positive (the default constructor sets them to -1).
    @pccm.pybind.mark
    @pccm.member_function
    def is_valid(self):
        code = pccm.code()
        # The statement needs its terminating ';' and the function must
        # declare the bool return type, otherwise the emitted C++ is invalid.
        code.raw("return splitk > 0 && std::get<0>(arch) > 0;")
        return code.ret("bool")
# pccm description of a C++ class parameterized by the "gemm" and "conv"
# kernel classes. It stores a vector of GEMM descriptors plus two
# unordered_maps: NVRTC programs keyed by string, and NVRTC modules keyed by
# a (string, int, int, uintptr_t) tuple.
class GemmTunerSimple(pccm.ParameterizedClass):
    # C++ snippet shared by both functions below: builds a tv::NVRTCProgram
    # from each binary blob in ``nvrtc_progs`` and stores it under its key.
    # ``emplace`` is required here: std::unordered_map has no
    # ``insert(key, value)`` overload, so the two-argument ``insert`` call
    # does not compile.
    _LOAD_NVRTC_PROGS_SRC = (
        "\n"
        "for (auto& v : nvrtc_progs){\n"
        "const uint8_t* code_ptr = reinterpret_cast<const uint8_t*>(v.second.c_str());\n"
        "nvrtc_progs_.emplace(v.first, tv::NVRTCProgram::from_binary(code_ptr, v.second.size()));\n"
        "}\n"
    )

    def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
        super().__init__()
        self.add_dependency(ExternalAllocator, GemmTuneResult, ConvTuneResult,
                            TensorView)
        self.add_param_class("gemm", gemm_cu, "GemmMain")
        self.add_param_class("conv", conv_cu, "ConvMain")
        self.add_include("tensorview/utility/tuplehash.h")
        self.add_member("desps_", "std::vector<tv::gemm::GemmAlgoDesp>")
        self.add_member("nvrtc_progs_",
                        "std::unordered_map<std::string, tv::NVRTCProgram>")
        self.add_member(
            "nvrtc_caches_",
            "std::unordered_map<std::tuple<std::string, int, int, std::uintptr_t>, tv::NVRTCModule>"
        )

    # Constructor: stores the descriptors and loads the NVRTC programs.
    @pccm.pybind.mark
    @pccm.constructor
    def ctor(self):
        code = pccm.code()
        code.arg("desps", "std::vector<tv::gemm::GemmAlgoDesp>")
        code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
        code.ctor_init("desps_", "desps")
        code.raw(self._LOAD_NVRTC_PROGS_SRC)
        return code

    # NOTE(review): the body is the constructor's program-loading loop and
    # does not use a/b/c, the trans_* flags or arch -- it looks like an
    # unfinished placeholder; confirm the intended implementation.
    @pccm.member_function
    def get_all_available(self):
        code = pccm.code()
        code.arg("a, b, c", "tv::Tensor")
        code.arg("trans_a, trans_b, trans_c", "bool")
        code.arg("arch", "std::tuple<int, int>")
        code.arg("nvrtc_progs", "std::unordered_map<std::string, std::string>")
        # No ``ctor_init("desps_", "desps")`` here: this is not a constructor
        # and ``desps`` is not one of this function's arguments.
        code.raw(self._LOAD_NVRTC_PROGS_SRC)
        return code
# pccm description of a C++ class parameterized by the "gemm" and "conv"
# kernel classes; in this view it declares a single static function,
# ``indice_conv``.
class ConvGemmOps(pccm.ParameterizedClass):
    def __init__(self, gemm_cu: GemmMainUnitTest, conv_cu: ConvMainUnitTest):
        super().__init__()
        self.add_dependency(ExternalAllocator, GemmTuneResult, ConvTuneResult)
        self.add_param_class("gemm", gemm_cu, "GemmMain")
        self.add_param_class("conv", conv_cu, "ConvMain")

    @pccm.pybind.mark
    @pccm.static_function
    def indice_conv(self):
        """1. this function needs to take the out features
        that come from the subm first mm.
        2. this function doesn't support CPU.
        """
        code = pccm.code()
        code.arg("allocator", "ExternalAllocator&")
        code.arg("out_features_after_mm", "tv::Tensor")
        code.arg("features, filters, indice_pairs", "tv::Tensor")
        code.arg("indice_pair_num", "tv::Tensor")
        code.arg("num_activate_out", "int")
        code.arg("inverse", "bool", "false")
        code.arg("subm", "bool", "false")
        code.arg("algo", "int", f"{ConvAlgo.Native.value}")
        code.arg("filter_hwio", "bool", "false")
        code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int")
        if CUMM_CPU_ONLY_BUILD:
            # CPU-only builds emit a body that just throws. The statement
            # needs its terminating ';' to be valid C++.
            code.raw(
                '\n'
                'throw std::runtime_error("this function can only be used with CUDA.");\n'
            )
            # NOTE(review): this branch declares a tv::Tensor return type
            # while the CUDA branch below declares none and uses bare
            # ``return;`` -- confirm which signature is intended.
            return code.ret("tv::Tensor")
        # CUDA body. It reads the output channel count from the filter layout,
        # flattens the filters to 3 dims, returns early for a 1-element subm
        # kernel or an all-zero subm pair count, and otherwise computes the
        # largest per-offset pair count on the host.
        # NOTE(review): the body ends after setting up locals (``a``, ``c``,
        # ``pair_in``, ``pair_out``) and looks unfinished -- confirm.
        code.raw(
            "\n"
            "TV_ASSERT_RT_ERR(!features.is_cpu(), \"this function don't support cpu.\");\n"
            "int out_channel;\n"
            "if (filter_hwio){\n"
            "out_channel = filters.dim(-1);\n"
            "}else{\n"
            "out_channel = filters.dim(-2);\n"
            "}\n"
            "filters = filters.view(-1, filters.dim(-2), filters.dim(-1));\n"
            "int kv = filters.dim(0);\n"
            "int kv_center = kv / 2;\n"
            "tv::Tensor out_features;\n"
            "if (kv == 1 && subm){\n"
            "return;\n"
            "}\n"
            "auto indice_pair_num_cpu = indice_pair_num.cpu();\n"
            "auto indice_pair_num_cpu_ptr = indice_pair_num_cpu.data_ptr<int>();\n"
            "int maxnhot = 0;\n"
            "bool all_zero = true;\n"
            "for (int i = 0; i < kv; ++i){\n"
            "if (indice_pair_num_cpu_ptr[i] != 0){\n"
            "all_zero = false;\n"
            "maxnhot = std::max(maxnhot, indice_pair_num_cpu_ptr[i]);\n"
            "}\n"
            "}\n"
            "if (subm && all_zero){\n"
            "return;\n"
            "}\n"
            "bool inited = subm;\n"
            "auto a = features;\n"
            "auto c = out_features;\n"
            "auto pair_in = indice_pairs[int(inverse)];\n"
            "auto pair_out = indice_pairs[int(!inverse)];\n"
        )
        return code
spconv/csrc/sparse/cpu_core.py
View file @
7af751dc
...
@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
...
@@ -23,13 +23,8 @@ class OMPLib(pccm.Class):
self
.
add_dependency
(
TensorView
)
self
.
add_dependency
(
TensorView
)
self
.
add_include
(
"tensorview/parallel/all.h"
)
self
.
add_include
(
"tensorview/parallel/all.h"
)
if
compat
.
InWindows
:
if
compat
.
InWindows
:
self
.
build_meta
.
add_cflags
(
"cl"
,
"/openmp"
)
self
.
build_meta
.
add_
public_
cflags
(
"cl"
,
"/openmp"
)
else
:
else
:
self
.
build_meta
.
add_cflags
(
"g++"
,
"-fopenmp"
)
self
.
build_meta
.
add_public_cflags
(
"g++"
,
"-fopenmp"
)
self
.
build_meta
.
add_cflags
(
"clang++"
,
"-fopenmp"
)
self
.
build_meta
.
add_public_cflags
(
"clang++"
,
"-fopenmp"
)
if
"g++"
not
in
self
.
build_meta
.
compiler_to_ldflags
:
self
.
build_meta
.
add_ldflags
(
"g++,clang++"
,
"-fopenmp"
)
self
.
build_meta
.
compiler_to_ldflags
[
"g++"
]
=
[]
self
.
build_meta
.
compiler_to_ldflags
[
"g++"
].
extend
([
"-fopenmp"
])
if
"clang++"
not
in
self
.
build_meta
.
compiler_to_ldflags
:
self
.
build_meta
.
compiler_to_ldflags
[
"clang++"
]
=
[]
self
.
build_meta
.
compiler_to_ldflags
[
"clang++"
].
extend
([
"-fopenmp"
])
spconv/csrc/sparse/indices.py
View file @
7af751dc
This diff is collapsed.
Click to expand it.
spconv/csrc/sparse/maxpool.py
View file @
7af751dc
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
contextlib
import
contextlib
from
cumm.conv.bases
import
ConvEnum
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm.gemm.core.metaarray
import
MetaArray
,
seq
from
cumm
import
dtypes
from
cumm
import
dtypes
import
pccm
import
pccm
...
@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -202,14 +201,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -244,14 +243,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -287,14 +286,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
...
@@ -331,14 +330,14 @@ class IndiceMaxPool(pccm.Class):
// if a value is found, other value won't be executed.
// if a value is found, other value won't be executed.
int NumFeatures = TV_DECLTYPE(V)::value;
int NumFeatures = TV_DECLTYPE(V)::value;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}});
}});
if (!found){{
if (!found){{
int NumFeatures = 16;
int NumFeatures = 16;
int Num0 = MaxThreads / NumFeatures;
int Num0 = MaxThreads / NumFeatures;
dim3 blocks(tv::div_up(out.dim(1), NumFeatures), tv::div_up(nhot, Num0));
dim3 blocks(tv::div_up(out.dim(1),
int64_t(
NumFeatures)
)
, tv::div_up(nhot,
int64_t(
Num0))
)
;
dim3 threads(NumFeatures, Num0);
dim3 threads(NumFeatures, Num0);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
launcher = tv::cuda::Launch(blocks, threads, cudastream);
}}
}}
...
...
spconv/csrc/sparse/pointops.py
View file @
7af751dc
...
@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
...
@@ -126,6 +126,7 @@ class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
super
().
__init__
()
super
().
__init__
()
self
.
add_dependency
(
TensorView
,
TensorViewHashKernel
)
self
.
add_dependency
(
TensorView
,
TensorViewHashKernel
)
self
.
add_param_class
(
"layout_ns"
,
layout
,
"Layout"
)
self
.
add_param_class
(
"layout_ns"
,
layout
,
"Layout"
)
self
.
add_include
(
"tensorview/hash/ops.h"
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
ndim
=
ndim
self
.
ndim
=
ndim
self
.
zyx
=
zyx
self
.
zyx
=
zyx
...
@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
...
@@ -447,7 +448,7 @@ class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin):
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
TV_ASSERT_RT_ERR(point_indice_data.dim(0) >= points.dim(0), "point_indice_data too small")
num_per_voxel.zero_(ctx);
num_per_voxel.zero_(ctx);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
table_t hash = table_t(hashdata.data_ptr<pair_t>(), expected_hash_data_num);
hash
.
clear
(
custream);
tv::
hash
::
clear
_map(hash,
custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream);
auto launcher = tv::cuda::Launch(points.dim(0), custream);
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
launcher(kernel::build_hash_table<table_t>, hash, points.data_ptr<const
{
self
.
dtype
}
>(),
point_indice_data.data_ptr<int64_t>(),
point_indice_data.data_ptr<int64_t>(),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment