Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
bf34f040
Commit
bf34f040
authored
Sep 25, 2022
by
yan.yan
Browse files
fix build and nvrtc problem
parent
8c25ed52
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
129 additions
and
41 deletions
+129
-41
CHANGELOG.md
CHANGELOG.md
+5
-0
pyproject.toml
pyproject.toml
+1
-1
setup.py
setup.py
+2
-2
spconv/algo.py
spconv/algo.py
+33
-11
spconv/core_cc/csrc/sparse/convops/__init__.pyi
spconv/core_cc/csrc/sparse/convops/__init__.pyi
+0
-1
spconv/core_cc/cumm/common.pyi
spconv/core_cc/cumm/common.pyi
+39
-0
spconv/cppconstants.py
spconv/cppconstants.py
+2
-2
spconv/csrc/sparse/all.py
spconv/csrc/sparse/all.py
+2
-0
spconv/csrc/sparse/convops.py
spconv/csrc/sparse/convops.py
+34
-18
spconv/pytorch/cppcore.py
spconv/pytorch/cppcore.py
+6
-5
test/dev.py
test/dev.py
+4
-0
version.txt
version.txt
+1
-1
No files found.
CHANGELOG.md
View file @
bf34f040
# Changelog
# Changelog
## [2.2.1] - 2022-9-25
### Fixed
-
Fix build problem
-
Fix nvrtc problem
## [2.2.0] - 2022-9-24
## [2.2.0] - 2022-9-24
### Added
### Added
-
Add Ampere support. faster fp16, faster tf32 and greatly faster int8 kernels in Ampere GPUs.
-
Add Ampere support. faster fp16, faster tf32 and greatly faster int8 kernels in Ampere GPUs.
...
...
pyproject.toml
View file @
bf34f040
[build-system]
[build-system]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.
1
"
]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.4.0"
,
"cumm>=0.3.
2
"
]
build-backend
=
"setuptools.build_meta"
build-backend
=
"setuptools.build_meta"
setup.py
View file @
bf34f040
...
@@ -38,9 +38,9 @@ if cuda_ver:
...
@@ -38,9 +38,9 @@ if cuda_ver:
cuda_ver
=
cuda_ver
.
replace
(
"."
,
""
)
# 10.2 to 102
cuda_ver
=
cuda_ver
.
replace
(
"."
,
""
)
# 10.2 to 102
RELEASE_NAME
+=
"-cu{}"
.
format
(
cuda_ver
)
RELEASE_NAME
+=
"-cu{}"
.
format
(
cuda_ver
)
deps
=
[
"cumm-cu{}>=0.3.
1
"
.
format
(
cuda_ver
)]
deps
=
[
"cumm-cu{}>=0.3.
2
"
.
format
(
cuda_ver
)]
else
:
else
:
deps
=
[
"cumm>=0.3.
1
"
]
deps
=
[
"cumm>=0.3.
2
"
]
...
...
spconv/algo.py
View file @
bf34f040
...
@@ -17,7 +17,7 @@ import time
...
@@ -17,7 +17,7 @@ import time
from
enum
import
Enum
from
enum
import
Enum
from
threading
import
Lock
from
threading
import
Lock
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
spconv.core_cc.cumm.common
import
CompileInfo
import
numpy
as
np
import
numpy
as
np
from
cumm
import
tensorview
as
tv
from
cumm
import
tensorview
as
tv
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
from
cumm.conv.bases
import
ConvLayout
,
ConvLayoutType
,
ConvOpType
...
@@ -337,9 +337,20 @@ class SimpleGemm:
...
@@ -337,9 +337,20 @@ class SimpleGemm:
ldb
=
b
.
stride
[
0
]
ldb
=
b
.
stride
[
0
]
ldc
=
c
.
stride
[
0
]
ldc
=
c
.
stride
[
0
]
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
if
arch
not
in
COMPILED_CUDA_GEMM_ARCHS
:
if
desp
.
is_nvrtc
:
if
not
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
continue
if
not
CompileInfo
.
arch_is_compiled_gemm
(
arch
):
# use PTX of possible
if
not
CompileInfo
.
gemm_algo_can_use_ptx
(
desp
.
min_arch
,
arch
):
if
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
# compiled kernel can't use PTX, for example, desp need at least sm_80 and only sm_75+PTX is compiled
# all sm_80 code of this desp is invalid, we must use nvrtc.
# only desp <= sm_75 can use virtual PTX code to generate sm_80 code.
desp
=
desp
.
copy
()
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
desp
.
is_nvrtc
=
True
else
:
continue
if
SPCONV_DEBUG_NVRTC_KERNELS
:
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
...
@@ -455,7 +466,7 @@ class SimpleGemm:
...
@@ -455,7 +466,7 @@ class SimpleGemm:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
params
=
GemmParams
()
params
=
GemmParams
()
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
if
desp
.
is_nvrtc
or
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
a
=
a
params
.
a
=
a
params
.
b
=
b
params
.
b
=
b
...
@@ -550,7 +561,7 @@ class SimpleGemm:
...
@@ -550,7 +561,7 @@ class SimpleGemm:
split_k_slices
=
profile_res
.
splitk
split_k_slices
=
profile_res
.
splitk
params
=
GemmParams
()
params
=
GemmParams
()
is_not_static
=
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
is_not_static
=
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
if
algo_desp
.
is_nvrtc
and
(
is_not_static
or
force_nvrtc
)
:
if
algo_desp
.
is_nvrtc
or
is_not_static
or
force_nvrtc
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
algo_desp
,
profile_res
.
arch
)
...
@@ -720,9 +731,20 @@ class SimpleConv:
...
@@ -720,9 +731,20 @@ class SimpleConv:
assert
mask_width
>
0
assert
mask_width
>
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
mask_width_valid
=
mask_width
%
desp
.
tile_shape
[
2
]
==
0
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
desp
.
supported_ldx_conv
(
ldi
,
ldw
,
ldo
)
and
mask_width_valid
:
if
arch
not
in
COMPILED_CUDA_GEMM_ARCHS
:
if
desp
.
is_nvrtc
:
if
not
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
continue
if
not
CompileInfo
.
arch_is_compiled_gemm
(
arch
):
# use PTX of possible
if
not
CompileInfo
.
gemm_algo_can_use_ptx
(
desp
.
min_arch
,
arch
):
if
CompileInfo
.
algo_can_be_nvrtc_compiled
(
desp
.
min_arch
):
# compiled kernel can't use PTX, for example, desp need at least sm_80 and only sm_75+PTX is compiled
# all sm_80 code of this desp is invalid, we must use nvrtc.
# only desp <= sm_75 can use virtual PTX code to generate sm_80 code.
desp
=
desp
.
copy
()
desp
=
desp
.
copy
()
desp
.
is_nvrtc
=
True
desp
.
is_nvrtc
=
True
else
:
continue
if
SPCONV_DEBUG_NVRTC_KERNELS
:
if
SPCONV_DEBUG_NVRTC_KERNELS
:
desp
.
is_nvrtc
=
True
desp
.
is_nvrtc
=
True
finally_algos
.
append
(
desp
)
finally_algos
.
append
(
desp
)
...
@@ -826,7 +848,7 @@ class SimpleConv:
...
@@ -826,7 +848,7 @@ class SimpleConv:
for
desp
in
avail
:
for
desp
in
avail
:
# for sparse conv, ndim isn't used, so we just provide a constant value.
# for sparse conv, ndim isn't used, so we just provide a constant value.
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type
.
value
))
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type
.
value
))
if
desp
.
is_nvrtc
and
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
if
desp
.
is_nvrtc
or
str
(
desp
)
not
in
self
.
prebuilt_desp_names
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
desp
,
arch
)
params
.
conv_algo_desp
=
desp
params
.
conv_algo_desp
=
desp
...
@@ -935,7 +957,7 @@ class SimpleConv:
...
@@ -935,7 +957,7 @@ class SimpleConv:
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type_value
))
params
=
ConvParams
(
NDIM_DONT_CARE
,
ConvOpTypeCpp
(
op_type_value
))
is_not_static
=
str
(
is_not_static
=
str
(
algo_desp
)
not
in
self
.
prebuilt_desp_names
algo_desp
)
not
in
self
.
prebuilt_desp_names
if
force_nvrtc
or
(
algo_desp
.
is_nvrtc
and
is_not_static
)
:
if
force_nvrtc
or
algo_desp
.
is_nvrtc
or
is_not_static
:
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
params
.
nvrtc_params
=
self
.
_cached_get_nvrtc_params
(
algo_desp
,
profile_res
.
arch
)
algo_desp
,
profile_res
.
arch
)
params
.
conv_algo_desp
=
profile_res
.
algo_desp
params
.
conv_algo_desp
=
profile_res
.
algo_desp
...
...
spconv/core_cc/csrc/sparse/convops/__init__.pyi
View file @
bf34f040
...
@@ -3,7 +3,6 @@ from pccm.stubs import EnumValue, EnumClassValue
...
@@ -3,7 +3,6 @@ from pccm.stubs import EnumValue, EnumClassValue
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview import Tensor
from ...csrc.sparse.convops import ExternalSpconvMatmul
class GemmTuneResult:
class GemmTuneResult:
algo_desp: GemmAlgoDesp
algo_desp: GemmAlgoDesp
arch: Tuple[int, int]
arch: Tuple[int, int]
...
...
spconv/core_cc/cumm/common.pyi
View file @
bf34f040
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue
class CompileInfo:
class CompileInfo:
@staticmethod
def get_compiled_cuda_version() -> Tuple[int, int]: ...
@staticmethod
@staticmethod
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
def get_compiled_cuda_arch() -> List[Tuple[int, int]]: ...
@staticmethod
@staticmethod
...
@@ -19,3 +21,40 @@ class CompileInfo:
...
@@ -19,3 +21,40 @@ class CompileInfo:
arch:
arch:
"""
"""
...
...
@staticmethod
def arch_is_compatible(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def arch_is_compatible_gemm(arch: Tuple[int, int]) -> bool:
"""
Args:
arch:
"""
...
@staticmethod
def algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def gemm_algo_can_use_ptx(min_arch: Tuple[int, int], arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
arch:
"""
...
@staticmethod
def algo_can_be_nvrtc_compiled(min_arch: Tuple[int, int]) -> bool:
"""
Args:
min_arch:
"""
...
spconv/cppconstants.py
View file @
bf34f040
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
import
spconv.core_cc
as
_ext
import
spconv.core_cc
as
_ext
from
spconv.core_cc.csrc.sparse.all
import
SpconvOps
from
spconv.core_cc.csrc.sparse.all
import
SpconvOps
from
spconv.core_cc.csrc.utils.boxops
import
BoxOps
from
spconv.core_cc.cumm.common
import
CompileInfo
CPU_ONLY_BUILD
=
SpconvOps
.
is_cpu_only_build
()
CPU_ONLY_BUILD
=
SpconvOps
.
is_cpu_only_build
()
BUILD_CUMM_VERSION
=
SpconvOps
.
cumm_version
()
BUILD_CUMM_VERSION
=
SpconvOps
.
cumm_version
()
BUILD_PCCM_VERSION
=
SpconvOps
.
pccm_version
()
BUILD_PCCM_VERSION
=
SpconvOps
.
pccm_version
()
from
spconv.core_cc.csrc.utils.boxops
import
BoxOps
from
spconv.core_cc.cumm.common
import
CompileInfo
HAS_BOOST
=
BoxOps
.
has_boost
()
HAS_BOOST
=
BoxOps
.
has_boost
()
COMPILED_CUDA_ARCHS
=
set
(
CompileInfo
.
get_compiled_cuda_arch
())
COMPILED_CUDA_ARCHS
=
set
(
CompileInfo
.
get_compiled_cuda_arch
())
...
...
spconv/csrc/sparse/all.py
View file @
bf34f040
...
@@ -131,6 +131,8 @@ class SpconvOps(pccm.Class):
...
@@ -131,6 +131,8 @@ class SpconvOps(pccm.Class):
define_str
=
"
\n
"
.
join
(
defines
)
define_str
=
"
\n
"
.
join
(
defines
)
self
.
add_global_code
(
define_str
)
self
.
add_global_code
(
define_str
)
self
.
build_meta
.
add_global_cflags
(
"cl"
,
"/DNOMINMAX"
)
self
.
build_meta
.
add_global_cflags
(
"cl"
,
"/DNOMINMAX"
)
# self.build_meta.add_global_cflags("nvcc", "-w")
# for name in dir(AllocKeys):
# for name in dir(AllocKeys):
# if not name.startswith("__"):
# if not name.startswith("__"):
# v = getattr(AllocKeys, name)
# v = getattr(AllocKeys, name)
...
...
spconv/csrc/sparse/convops.py
View file @
bf34f040
...
@@ -550,7 +550,6 @@ class GemmTunerSimple(pccm.ParameterizedClass):
...
@@ -550,7 +550,6 @@ class GemmTunerSimple(pccm.ParameterizedClass):
}}
}}
// auto avail_algos = get_available_algo_str_from_arch(arch);
// auto avail_algos = get_available_algo_str_from_arch(arch);
std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
std::vector<tv::gemm::GemmAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
static_key_t static_key = std::make_tuple(trans_a, trans_b, trans_c, int(a.dtype()),
int(b.dtype()), int(c.dtype()), shuffle_type);
int(b.dtype()), int(c.dtype()), shuffle_type);
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
if (static_key_to_desps_.find(static_key) == static_key_to_desps_.end()){{
...
@@ -574,14 +573,23 @@ class GemmTunerSimple(pccm.ParameterizedClass):
...
@@ -574,14 +573,23 @@ class GemmTunerSimple(pccm.ParameterizedClass):
auto ldb = b.stride(0);
auto ldb = b.stride(0);
auto ldc = c.stride(0);
auto ldc = c.stride(0);
if (desp.supported_ldx(lda, ldb, ldc)){{
if (desp.supported_ldx(lda, ldb, ldc)){{
if (!is_arch_compiled){{
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
continue;
}}
}}
if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
auto desp2 = desp;
desp2.is_nvrtc = true;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
}}else{{
finally_algos.push_back(desp);
continue;
}}
}}
}}
}}
}}
finally_algos.push_back(desp);
}}
}}
}}
return finally_algos;
return finally_algos;
"""
)
"""
)
...
@@ -699,7 +707,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
...
@@ -699,7 +707,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
for (auto& desp : avail){{
for (auto& desp : avail){{
tv::gemm::GemmParams params;
tv::gemm::GemmParams params;
if (desp.is_nvrtc
&&
prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
if (desp.is_nvrtc
||
prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
}}
params.a = a;
params.a = a;
...
@@ -865,7 +873,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
...
@@ -865,7 +873,7 @@ class GemmTunerSimple(pccm.ParameterizedClass):
tv::gemm::GemmParams params;
tv::gemm::GemmParams params;
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc
&&
desp_is_static)){{
if (force_nvrtc || (desp.is_nvrtc
||
desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
params.nvrtc_params = cached_get_nvrtc_params(desp, profile_res.arch, stream_int);
}}
}}
params.a = a;
params.a = a;
...
@@ -1008,7 +1016,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1008,7 +1016,6 @@ class ConvTunerSimple(pccm.ParameterizedClass):
use_f32_as_accum = false;
use_f32_as_accum = false;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
std::vector<tv::gemm::ConvAlgoDesp> finally_algos;
auto is_arch_compiled = CompileInfo::arch_is_compiled_gemm(arch);
static_key_t static_key = std::make_tuple(
static_key_t static_key = std::make_tuple(
layout_i, layout_w, layout_o,
layout_i, layout_w, layout_o,
interleave_i, interleave_w, interleave_o, inp.dtype(),
interleave_i, interleave_w, interleave_o, inp.dtype(),
...
@@ -1053,14 +1060,23 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1053,14 +1060,23 @@ class ConvTunerSimple(pccm.ParameterizedClass):
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
mask_width_valid = mask_width % desp.tile_shape[2] == 0;
}}
}}
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (desp.supported_ldx_conv(ldi, ldw, ldo) && mask_width_valid){{
if (!is_arch_compiled){{
if (desp.is_nvrtc){{
if (!CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
continue;
}}
}}
if (!CompileInfo::arch_is_compiled_gemm(arch)){{
if (!CompileInfo::gemm_algo_can_use_ptx(desp.min_arch, arch)){{
if (CompileInfo::algo_can_be_nvrtc_compiled(desp.min_arch)){{
auto desp2 = desp;
auto desp2 = desp;
desp2.is_nvrtc = true;
desp2.is_nvrtc = true;
finally_algos.push_back(desp2);
}}else{{
}}else{{
finally_algos.push_back(desp);
continue;
}}
}}
}}
}}
}}
finally_algos.push_back(desp);
}}
}}
}}
return finally_algos;
return finally_algos;
"""
)
"""
)
...
@@ -1134,7 +1150,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1134,7 +1150,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
tv::gemm::ConvOpType op_type_cpp = static_cast<tv::gemm::ConvOpType>(op_type);
for (auto& desp : avail){{
for (auto& desp : avail){{
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, tv::CUDAKernelTimer(false));
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, tv::CUDAKernelTimer(false));
if (desp.is_nvrtc
&&
prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
if (desp.is_nvrtc
||
prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end()){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
}}
params.conv_algo_desp = desp;
params.conv_algo_desp = desp;
...
@@ -1311,7 +1327,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
...
@@ -1311,7 +1327,7 @@ class ConvTunerSimple(pccm.ParameterizedClass):
auto arch = profile_res.arch;
auto arch = profile_res.arch;
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, timer);
tv::gemm::ConvParams params(
{
NDIM_DONT_CARE
}
, op_type_cpp, timer);
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
bool desp_is_static = prebuilt_names_.find(desp.__repr__()) == prebuilt_names_.end();
if (force_nvrtc || (desp.is_nvrtc
&&
desp_is_static)){{
if (force_nvrtc || (desp.is_nvrtc
||
desp_is_static)){{
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
params.nvrtc_params = cached_get_nvrtc_params(desp, arch, stream_int);
}}
}}
params.conv_algo_desp = desp;
params.conv_algo_desp = desp;
...
...
spconv/pytorch/cppcore.py
View file @
bf34f040
...
@@ -20,6 +20,8 @@ from spconv.cppconstants import COMPILED_CUDA_ARCHS
...
@@ -20,6 +20,8 @@ from spconv.cppconstants import COMPILED_CUDA_ARCHS
import
sys
import
sys
from
spconv.core_cc.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.core_cc.csrc.sparse.alloc
import
ExternalAllocator
from
spconv.core_cc.csrc.sparse.convops
import
ExternalSpconvMatmul
from
spconv.core_cc.csrc.sparse.convops
import
ExternalSpconvMatmul
from
spconv.core_cc.cumm.common
import
CompileInfo
import
warnings
import
numpy
as
np
import
numpy
as
np
...
@@ -93,12 +95,11 @@ def get_current_stream():
...
@@ -93,12 +95,11 @@ def get_current_stream():
def
get_arch
():
def
get_arch
():
arch
=
torch
.
cuda
.
get_device_capability
()
arch
=
torch
.
cuda
.
get_device_capability
()
if
arch
not
in
COMPILED_CUDA_ARCHS
:
if
not
CompileInfo
.
arch_is_compatible
(
arch
)
and
not
CompileInfo
.
algo_can_use_ptx
((
0
,
0
),
arch
)
:
print
(
warnings
.
warn
(
f
"[WARNING]your gpu arch
{
arch
}
isn't compiled in prebuilt, "
f
"[WARNING]your gpu arch
{
arch
}
isn't compiled in prebuilt, "
f
"may cause invalid device function. "
f
"may cause invalid device function error. "
f
"available:
{
COMPILED_CUDA_ARCHS
}
"
,
f
"available:
{
COMPILED_CUDA_ARCHS
}
"
)
file
=
sys
.
stderr
)
return
arch
return
arch
...
...
test/dev.py
View file @
bf34f040
import
spconv
import
spconv
from
spconv.core_cc.cumm.common
import
CompileInfo
if
__name__
==
"__main__"
:
print
(
CompileInfo
.
arch_is_compatible_gemm
((
9
,
0
)),
CompileInfo
.
arch_is_compiled_gemm
((
9
,
0
)))
print
(
CompileInfo
.
arch_is_compatible_gemm
((
8
,
6
)),
CompileInfo
.
arch_is_compiled_gemm
((
8
,
6
)))
\ No newline at end of file
version.txt
View file @
bf34f040
2.2.
0
2.2.
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment