Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
3b545945
Commit
3b545945
authored
May 09, 2026
by
one
Browse files
WIP
parent
263d6b47
Pipeline
#3583
canceled with stages
Changes
5
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
161 additions
and
3 deletions
+161
-3
spconv/constants.py
spconv/constants.py
+1
-1
spconv/core.py
spconv/core.py
+147
-0
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+1
-1
spconv/csrc/sparse/indices.py
spconv/csrc/sparse/indices.py
+5
-0
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+7
-1
No files found.
spconv/constants.py
View file @
3b545945
...
...
@@ -60,7 +60,7 @@ SPCONV_BWD_SPLITK = list(map(int, os.getenv("SPCONV_BWD_SPLITK", "1,2,4,8,16,32,
SPCONV_NVRTC_MODE
=
NVRTCMode
.
ConstantMemory
SPCONV_DEBUG_NVRTC_KERNELS
=
False
SPCONV_DEBUG_CPP_ONLY
=
project_is_editable
(
PACKAGE_NAME
)
SPCONV_DEBUG_CPP_ONLY
=
EDITABLE_INSTALLED
class
AllocKeys
:
...
...
spconv/core.py
View file @
3b545945
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
enum
import
Enum
from
cumm.gemm.main
import
gen_shuffle_params_v2
as
gen_shuffle_params
,
GemmAlgoParams
from
cumm.gemm
import
kernel
...
...
@@ -1338,6 +1339,152 @@ if not SPCONV_INT8_DEBUG:
int8_inference
=
True
),
])
def
_dtype_shortcuts
(
param
):
return
tuple
(
getattr
(
param
,
name
).
shortcut
()
for
name
in
(
"dtype_a"
,
"dtype_b"
,
"dtype_c"
)
)
def
_all_dtypes_are
(
param
,
dtype_shortcut
:
str
):
return
all
(
shortcut
==
dtype_shortcut
for
shortcut
in
_dtype_shortcuts
(
param
))
def
_has_any_dtype
(
param
,
dtype_shortcuts
):
return
any
(
shortcut
in
dtype_shortcuts
for
shortcut
in
_dtype_shortcuts
(
param
))
def
_is_fp32_simt_param
(
param
):
return
param
.
algo
==
GemmAlgo
.
Simt
and
_all_dtypes_are
(
param
,
"f32"
)
def
_is_ampere_param
(
param
):
return
param
.
algo
==
GemmAlgo
.
Ampere
def
_is_static_param
(
param
):
return
not
getattr
(
param
,
"is_nvrtc"
,
False
)
def
_is_non_int8_param
(
param
):
return
not
getattr
(
param
,
"int8_inference"
,
False
)
def
_is_fp32_ampere_param
(
param
):
return
(
_is_ampere_param
(
param
)
and
_is_static_param
(
param
)
and
_is_non_int8_param
(
param
)
and
_all_dtypes_are
(
param
,
"f32"
)
)
def
_is_f16_ampere_param
(
param
):
return
(
_is_ampere_param
(
param
)
and
_is_static_param
(
param
)
and
_is_non_int8_param
(
param
)
and
_has_any_dtype
(
param
,
{
"f16"
})
)
def
_is_ampere_no_int8_static_param
(
param
):
return
_is_ampere_param
(
param
)
and
_is_static_param
(
param
)
and
_is_non_int8_param
(
param
)
def
_is_ampere_int8_param
(
param
):
return
_is_ampere_param
(
param
)
and
getattr
(
param
,
"int8_inference"
,
False
)
def
_is_non_int8_nvrtc_param
(
param
):
return
getattr
(
param
,
"is_nvrtc"
,
False
)
and
_is_non_int8_param
(
param
)
def
_is_fp8_param
(
param
):
return
_has_any_dtype
(
param
,
{
"e4m3"
,
"e5m2"
})
def
_filter_params
(
params
,
predicate
):
return
[
param
for
param
in
params
if
predicate
(
param
)]
def
_clear_turing_volta
():
global
SHUFFLE_TURING_PARAMS
,
SHUFFLE_VOLTA_PARAMS
global
IMPLGEMM_TURING_PARAMS
,
IMPLGEMM_VOLTA_PARAMS
SHUFFLE_TURING_PARAMS
=
[]
SHUFFLE_VOLTA_PARAMS
=
[]
IMPLGEMM_TURING_PARAMS
=
[]
IMPLGEMM_VOLTA_PARAMS
=
[]
_DTK_KERNEL_FILTER
=
os
.
getenv
(
"SPCONV_DTK_KERNEL_FILTER"
,
""
).
lower
()
if
_DTK_KERNEL_FILTER
==
"dtk_smoke"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_fp32_simt_param
)[:
4
]
SHUFFLE_AMPERE_PARAMS
=
[]
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_fp32_simt_param
)[:
4
]
IMPLGEMM_AMPERE_PARAMS
=
[]
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_fp32_simt"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_fp32_simt_param
)
SHUFFLE_AMPERE_PARAMS
=
[]
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_fp32_simt_param
)
IMPLGEMM_AMPERE_PARAMS
=
[]
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_all_simt"
:
SHUFFLE_AMPERE_PARAMS
=
[]
IMPLGEMM_AMPERE_PARAMS
=
[]
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_fp32_ampere"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_fp32_simt_param
)
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
_is_fp32_ampere_param
)
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_fp32_simt_param
)
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
_is_fp32_ampere_param
)
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_f16_ampere"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_fp32_simt_param
)
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
_is_f16_ampere_param
)
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_fp32_simt_param
)
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
_is_f16_ampere_param
)
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_ampere_no_int8"
:
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
_is_ampere_no_int8_static_param
)
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
_is_ampere_no_int8_static_param
)
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_int8_ampere"
:
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
lambda
param
:
_is_ampere_no_int8_static_param
(
param
)
or
_is_ampere_int8_param
(
param
))
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
lambda
param
:
_is_ampere_no_int8_static_param
(
param
)
or
_is_ampere_int8_param
(
param
))
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_nvrtc"
:
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
lambda
param
:
_is_ampere_no_int8_static_param
(
param
)
or
_is_non_int8_nvrtc_param
(
param
))
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
lambda
param
:
_is_ampere_no_int8_static_param
(
param
)
or
_is_non_int8_nvrtc_param
(
param
))
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_fp8_probe"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_fp32_simt_param
)
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
_is_fp8_param
)
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_fp32_simt_param
)
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
_is_fp8_param
)
_clear_turing_volta
()
elif
_DTK_KERNEL_FILTER
==
"dtk_all_static_no_nvrtc"
:
SHUFFLE_SIMT_PARAMS
=
_filter_params
(
SHUFFLE_SIMT_PARAMS
,
_is_static_param
)
SHUFFLE_TURING_PARAMS
=
_filter_params
(
SHUFFLE_TURING_PARAMS
,
_is_static_param
)
SHUFFLE_VOLTA_PARAMS
=
_filter_params
(
SHUFFLE_VOLTA_PARAMS
,
_is_static_param
)
SHUFFLE_AMPERE_PARAMS
=
_filter_params
(
SHUFFLE_AMPERE_PARAMS
,
_is_static_param
)
IMPLGEMM_SIMT_PARAMS
=
_filter_params
(
IMPLGEMM_SIMT_PARAMS
,
_is_static_param
)
IMPLGEMM_TURING_PARAMS
=
_filter_params
(
IMPLGEMM_TURING_PARAMS
,
_is_static_param
)
IMPLGEMM_VOLTA_PARAMS
=
_filter_params
(
IMPLGEMM_VOLTA_PARAMS
,
_is_static_param
)
IMPLGEMM_AMPERE_PARAMS
=
_filter_params
(
IMPLGEMM_AMPERE_PARAMS
,
_is_static_param
)
elif
_DTK_KERNEL_FILTER
==
"dtk_all"
:
pass
ALL_NATIVE_PARAMS
=
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_TURING_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_AMPERE_PARAMS
ALL_IMPGEMM_PARAMS
=
IMPLGEMM_SIMT_PARAMS
+
IMPLGEMM_TURING_PARAMS
+
IMPLGEMM_VOLTA_PARAMS
+
IMPLGEMM_AMPERE_PARAMS
spconv/csrc/sparse/all.py
View file @
3b545945
...
...
@@ -36,7 +36,7 @@ class CustomThrustLib(pccm.Class):
super
().
__init__
()
self
.
add_dependency
(
ThrustLib
)
# https://github.com/NVIDIA/thrust/issues/1401#issuecomment-806403746
if
compat
.
InLinux
:
if
compat
.
InLinux
and
os
.
getenv
(
"CUMM_DTK_DISABLE_INLINE_PTX"
,
"0"
)
!=
"1"
:
self
.
build_meta
.
add_public_cflags
(
"nvcc"
,
"-Xcompiler -fno-gnu-unique"
,
"-Xcompiler -fvisibility=hidden"
)
...
...
spconv/csrc/sparse/indices.py
View file @
3b545945
...
...
@@ -35,6 +35,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
def
arange_kernel
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
add_pre_attr
(
"__launch_bounds__(1024)"
)
code
.
targ
(
"T"
)
code
.
arg
(
"data"
,
f
"T*"
)
code
.
arg
(
"size"
,
f
"int"
)
...
...
@@ -48,6 +49,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
def
fill_kernel
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
add_pre_attr
(
"__launch_bounds__(1024)"
)
code
.
targ
(
"T"
)
code
.
arg
(
"data"
,
f
"T*"
)
code
.
arg
(
"val"
,
f
"T"
)
...
...
@@ -62,6 +64,7 @@ class CudaCommonKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
def
maximum_value_kernel
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
add_pre_attr
(
"__launch_bounds__(1024)"
)
code
.
targ
(
"T"
)
code
.
arg
(
"data"
,
f
"T*"
)
code
.
arg
(
"val"
,
f
"T"
)
...
...
@@ -723,6 +726,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
def
build_subm_conv_hash_table
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
add_pre_attr
(
"__launch_bounds__(1024)"
)
code
.
targ
(
"TTable"
)
code
.
targ
(
"TLayoutNPQ"
)
...
...
@@ -806,6 +810,7 @@ class SparseConvIndicesKernel(pccm.ParameterizedClass):
@
pccm
.
cuda
.
cuda_global_function
def
calc_subm_conv_indices_mask
(
self
):
code
=
pccm
.
FunctionCode
()
code
.
add_pre_attr
(
"__launch_bounds__(1024)"
)
code
.
targ
(
"TTable"
)
code
.
targ
(
"TConvLocIter"
)
code
.
arg
(
"loc_iter"
,
f
"TConvLocIter"
)
# [N, ndim + 1]
...
...
spconv/pytorch/cppcore.py
View file @
3b545945
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
cumm
import
tensorview
as
tv
import
os
import
torch
from
typing
import
Dict
,
Optional
,
List
,
Union
from
spconv.constants
import
AllocKeys
...
...
@@ -100,6 +101,11 @@ def get_current_stream():
def
get_arch
():
force_arch
=
os
.
getenv
(
"SPCONV_FORCE_CUDA_ARCH"
,
""
)
if
force_arch
:
force_arch
=
force_arch
.
replace
(
"."
,
""
)
arch
=
(
int
(
force_arch
[:
-
1
]),
int
(
force_arch
[
-
1
]))
else
:
arch
=
torch
.
cuda
.
get_device_capability
()
if
not
CompileInfo
.
arch_is_compatible
(
arch
)
and
not
CompileInfo
.
algo_can_use_ptx
((
0
,
0
),
arch
):
warnings
.
warn
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment