Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
7f91c408
Commit
7f91c408
authored
Oct 10, 2020
by
yan.yan
Browse files
fix cuda 11 build
parent
42d92ee8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
82 additions
and
58 deletions
+82
-58
include/tensorview/common.h
include/tensorview/common.h
+17
-0
include/tensorview/tensor.h
include/tensorview/tensor.h
+15
-15
include/tensorview/torch_utils.h
include/tensorview/torch_utils.h
+4
-4
setup.py
setup.py
+6
-1
src/spconv/indice.cc
src/spconv/indice.cc
+4
-4
src/spconv/indice.cu
src/spconv/indice.cu
+14
-14
src/spconv/maxpool.cc
src/spconv/maxpool.cc
+4
-4
src/spconv/maxpool.cu
src/spconv/maxpool.cu
+4
-4
src/spconv/reordering.cc
src/spconv/reordering.cc
+4
-4
src/spconv/reordering.cu
src/spconv/reordering.cu
+8
-8
third_party/cutlass
third_party/cutlass
+1
-0
third_party/mp11
third_party/mp11
+1
-0
No files found.
include/tensorview/common.h
View file @
7f91c408
...
@@ -24,6 +24,23 @@
...
@@ -24,6 +24,23 @@
#endif
#endif
#include <boost/stacktrace.hpp>
#include <boost/stacktrace.hpp>
#endif
#endif
#ifdef TV_CUDA
#include <cuda.h>
#endif
#if defined(TV_USE_BOOST_TYPEOF) || (!defined(__clang__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
// a workaround when built with cuda 11
// two options: use BOOST_TYPEOF or identity_t.
// this is a nvcc bug, msvc/gcc/clang don't have this problem.
// #include <boost/typeof/typeof.hpp>
// #define TV_DECLTYPE(x) BOOST_TYPEOF(x)
namespace
tv
{
template
<
typename
T
>
using
identity_t
=
T
;
}
#define TV_DECLTYPE(x) tv::identity_t<decltype(x)>
#else
#define TV_DECLTYPE(x) decltype(x)
#endif
namespace
tv
{
namespace
tv
{
...
...
include/tensorview/tensor.h
View file @
7f91c408
...
@@ -318,8 +318,8 @@ template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) {
...
@@ -318,8 +318,8 @@ template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) {
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
bool
notFound
=
true
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
type_v
<
decltype
(
I
)
>
==
t
&&
notFound
)
{
if
(
type_v
<
TV_DECLTYPE
(
I
)
>
==
t
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
decltype
(
I
)());
std
::
forward
<
F
>
(
f
)(
TV_DECLTYPE
(
I
)());
notFound
=
false
;
notFound
=
false
;
}
}
});
});
...
@@ -330,7 +330,7 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) {
...
@@ -330,7 +330,7 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) {
if
(
!
dispatch_noexcept
<
Ts
...
>
(
t
,
std
::
forward
<
F
>
(
f
)))
{
if
(
!
dispatch_noexcept
<
Ts
...
>
(
t
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
detail
::
TypeToString
<
decltype
(
I
)
>::
value
<<
" "
;
ss
<<
detail
::
TypeToString
<
TV_DECLTYPE
(
I
)
>::
value
<<
" "
;
});
});
TV_THROW_RT_ERR
(
"unknown type"
,
detail
::
typeString
(
t
),
TV_THROW_RT_ERR
(
"unknown type"
,
detail
::
typeString
(
t
),
", available:"
,
ss
.
str
());
", available:"
,
ss
.
str
());
...
@@ -359,7 +359,7 @@ template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) {
...
@@ -359,7 +359,7 @@ template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) {
"you need to provide at least one candidate"
);
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
decltype
(
I
)
::
value
==
idx
&&
notFound
)
{
if
(
TV_DECLTYPE
(
I
)
::
value
==
idx
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
notFound
=
false
;
}
}
...
@@ -373,7 +373,7 @@ bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) {
...
@@ -373,7 +373,7 @@ bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) {
"you need to provide at least one candidate"
);
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
p
(
idx
,
decltype
(
I
)
::
value
)
&&
notFound
)
{
if
(
p
(
idx
,
TV_DECLTYPE
(
I
)
::
value
)
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
notFound
=
false
;
}
}
...
@@ -385,7 +385,7 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
...
@@ -385,7 +385,7 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
std
::
forward
<
F
>
(
f
)))
{
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
decltype
(
I
)
::
value
<<
" "
;
});
[
=
,
&
ss
](
auto
I
)
{
ss
<<
TV_DECLTYPE
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
}
}
...
@@ -396,7 +396,7 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) {
...
@@ -396,7 +396,7 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) {
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
p
,
std
::
forward
<
F
>
(
f
)))
{
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
p
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
decltype
(
I
)
::
value
<<
" "
;
});
[
=
,
&
ss
](
auto
I
)
{
ss
<<
TV_DECLTYPE
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
}
}
...
@@ -408,7 +408,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
...
@@ -408,7 +408,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
"you need to provide at least one candidate"
);
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
bool
notFound
=
true
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
using
val_lst_t
=
decltype
(
I
);
using
val_lst_t
=
TV_DECLTYPE
(
I
);
auto
val_lst_size
=
mp_size
<
val_lst_t
>::
value
;
auto
val_lst_size
=
mp_size
<
val_lst_t
>::
value
;
bool
equal
=
true
;
bool
equal
=
true
;
std
::
size_t
count
=
0
;
std
::
size_t
count
=
0
;
...
@@ -420,7 +420,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
...
@@ -420,7 +420,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
if
(
count
>=
val_lst_size
)
{
if
(
count
>=
val_lst_size
)
{
TV_THROW_INVALID_ARG
(
"iterator length invalid:"
,
val_lst_size
);
TV_THROW_INVALID_ARG
(
"iterator length invalid:"
,
val_lst_size
);
}
}
constexpr
auto
c
=
decltype
(
E
)
::
value
;
constexpr
auto
c
=
TV_DECLTYPE
(
E
)
::
value
;
if
(
c
!=
*
iter
)
{
if
(
c
!=
*
iter
)
{
equal
=
false
;
equal
=
false
;
}
}
...
@@ -450,8 +450,8 @@ void dispatch_container(Iterator begin, Iterator end, F &&f) {
...
@@ -450,8 +450,8 @@ void dispatch_container(Iterator begin, Iterator end, F &&f) {
ss
<<
"], available: "
;
ss
<<
"], available: "
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
"["
;
ss
<<
"["
;
mp_for_each
<
decltype
(
I
)
>
(
mp_for_each
<
TV_DECLTYPE
(
I
)
>
(
[
=
,
&
ss
](
auto
E
)
{
ss
<<
decltype
(
E
)
::
value
<<
","
;
});
[
=
,
&
ss
](
auto
E
)
{
ss
<<
TV_DECLTYPE
(
E
)
::
value
<<
","
;
});
ss
<<
"]"
;
ss
<<
"]"
;
});
});
TV_THROW_RT_ERR
(
ss
.
str
());
TV_THROW_RT_ERR
(
ss
.
str
());
...
@@ -791,7 +791,7 @@ struct Tensor {
...
@@ -791,7 +791,7 @@ struct Tensor {
writable_check
();
writable_check
();
TV_ASSERT_RT_ERR
(
device
()
==
-
1
,
"error"
);
TV_ASSERT_RT_ERR
(
device
()
==
-
1
,
"error"
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype_
,
[
&
](
auto
I
)
{
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype_
,
[
&
](
auto
I
)
{
using
Treal
=
decltype
(
I
);
using
Treal
=
TV_DECLTYPE
(
I
);
if
(
std
::
is_convertible
<
T
,
Treal
>::
value
)
{
if
(
std
::
is_convertible
<
T
,
Treal
>::
value
)
{
auto
ptr
=
reinterpret_cast
<
Treal
*>
(
raw_data
());
auto
ptr
=
reinterpret_cast
<
Treal
*>
(
raw_data
());
std
::
fill
(
ptr
,
ptr
+
size
(),
Treal
(
value
));
std
::
fill
(
ptr
,
ptr
+
size
(),
Treal
(
value
));
...
@@ -940,9 +940,9 @@ struct Tensor {
...
@@ -940,9 +940,9 @@ struct Tensor {
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
auto
tensor
=
Tensor
();
auto
tensor
=
Tensor
();
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype
,
[
&
](
auto
Idst
)
{
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype
,
[
&
](
auto
Idst
)
{
using
Tdst
=
decltype
(
Idst
);
using
Tdst
=
TV_DECLTYPE
(
Idst
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
this
->
dtype_
,
[
&
](
auto
Icur
)
{
Dispatch
<
detail
::
all_tensor_types_t
>
()(
this
->
dtype_
,
[
&
](
auto
Icur
)
{
using
Tcur
=
decltype
(
Icur
);
using
Tcur
=
TV_DECLTYPE
(
Icur
);
if
(
std
::
is_convertible
<
Tcur
,
Tdst
>::
value
)
{
if
(
std
::
is_convertible
<
Tcur
,
Tdst
>::
value
)
{
auto
ptr
=
this
->
data
<
Tcur
>
();
auto
ptr
=
this
->
data
<
Tcur
>
();
tensor
=
Tensor
(
this
->
shape_
,
this
->
stride_
,
dtype
,
this
->
device
(),
tensor
=
Tensor
(
this
->
shape_
,
this
->
stride_
,
dtype
,
this
->
device
(),
...
@@ -981,7 +981,7 @@ private:
...
@@ -981,7 +981,7 @@ private:
template
<
typename
Os
>
Os
&
operator
<<
(
Os
&
os
,
const
Tensor
&
tensor
)
{
template
<
typename
Os
>
Os
&
operator
<<
(
Os
&
os
,
const
Tensor
&
tensor
)
{
TV_ASSERT_INVALID_ARG
(
tensor
.
device
()
==
-
1
,
"must be cpu tensor"
);
TV_ASSERT_INVALID_ARG
(
tensor
.
device
()
==
-
1
,
"must be cpu tensor"
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
tensor
.
dtype
(),
[
&
](
auto
I
)
{
Dispatch
<
detail
::
all_tensor_types_t
>
()(
tensor
.
dtype
(),
[
&
](
auto
I
)
{
using
T
=
decltype
(
I
);
using
T
=
TV_DECLTYPE
(
I
);
std
::
stringstream
ss
;
std
::
stringstream
ss
;
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
{
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
{
ss
<<
std
::
setprecision
(
4
);
ss
<<
std
::
setprecision
(
4
);
...
...
include/tensorview/torch_utils.h
View file @
7f91c408
...
@@ -76,15 +76,15 @@ void dispatch_torch(at::ScalarType t, F &&f) {
...
@@ -76,15 +76,15 @@ void dispatch_torch(at::ScalarType t, F &&f) {
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
bool
notFound
=
true
;
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
detail
::
TypeToTorchDtypeTraits
<
decltype
(
I
)
>::
value
==
t
)
{
if
(
detail
::
TypeToTorchDtypeTraits
<
TV_DECLTYPE
(
I
)
>::
value
==
t
)
{
std
::
forward
<
F
>
(
f
)(
decltype
(
I
)());
std
::
forward
<
F
>
(
f
)(
TV_DECLTYPE
(
I
)());
notFound
=
false
;
notFound
=
false
;
}
}
});
});
if
(
notFound
)
{
if
(
notFound
)
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
tv
::
detail
::
TypeToString
<
decltype
(
I
)
>::
value
<<
" "
;
ss
<<
tv
::
detail
::
TypeToString
<
TV_DECLTYPE
(
I
)
>::
value
<<
" "
;
});
});
TV_THROW_RT_ERR
(
"unknown type"
,
t
,
", available:"
,
ss
.
str
());
TV_THROW_RT_ERR
(
"unknown type"
,
t
,
", available:"
,
ss
.
str
());
}
}
...
@@ -101,7 +101,7 @@ struct DispatchTorch<T<Args...>> {
...
@@ -101,7 +101,7 @@ struct DispatchTorch<T<Args...>> {
template
<
typename
T
>
void
check_torch_dtype
(
const
torch
::
Tensor
&
tensor
)
{
template
<
typename
T
>
void
check_torch_dtype
(
const
torch
::
Tensor
&
tensor
)
{
DispatchTorch
<
detail
::
all_torch_types_t
>
()(
tensor
.
scalar_type
(),
[
&
](
auto
I
)
{
DispatchTorch
<
detail
::
all_torch_types_t
>
()(
tensor
.
scalar_type
(),
[
&
](
auto
I
)
{
using
Ttensor
=
decltype
(
I
);
using
Ttensor
=
TV_DECLTYPE
(
I
);
constexpr
bool
val
=
std
::
is_same
<
std
::
remove_cv_t
<
T
>
,
Ttensor
>::
value
;
constexpr
bool
val
=
std
::
is_same
<
std
::
remove_cv_t
<
T
>
,
Ttensor
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
TV_ASSERT_RT_ERR
(
val
,
"error"
);
});
});
...
...
setup.py
View file @
7f91c408
...
@@ -19,10 +19,15 @@ SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA")
...
@@ -19,10 +19,15 @@ SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA")
PYTHON_VERSION
=
"{}.{}"
.
format
(
sys
.
version_info
.
major
,
sys
.
version_info
.
minor
)
PYTHON_VERSION
=
"{}.{}"
.
format
(
sys
.
version_info
.
major
,
sys
.
version_info
.
minor
)
remove_plus
=
torch
.
__version__
.
find
(
"+"
)
remove_plus
=
torch
.
__version__
.
find
(
"+dev"
)
remove_dot
=
torch
.
__version__
.
find
(
".dev"
)
PYTORCH_VERSION
=
torch
.
__version__
PYTORCH_VERSION
=
torch
.
__version__
if
remove_plus
!=
-
1
:
if
remove_plus
!=
-
1
:
PYTORCH_VERSION
=
torch
.
__version__
[:
remove_plus
]
PYTORCH_VERSION
=
torch
.
__version__
[:
remove_plus
]
if
remove_dot
!=
-
1
:
PYTORCH_VERSION
=
torch
.
__version__
[:
remove_dot
]
PYTORCH_VERSION
=
list
(
map
(
int
,
PYTORCH_VERSION
.
split
(
"."
)))
PYTORCH_VERSION
=
list
(
map
(
int
,
PYTORCH_VERSION
.
split
(
"."
)))
PYTORCH_VERSION_NUMBER
=
PYTORCH_VERSION
[
0
]
*
10000
+
PYTORCH_VERSION
[
1
]
*
100
+
PYTORCH_VERSION
[
2
]
PYTORCH_VERSION_NUMBER
=
PYTORCH_VERSION
[
0
]
*
10000
+
PYTORCH_VERSION
[
1
]
*
100
+
PYTORCH_VERSION
[
2
]
...
...
src/spconv/indice.cc
View file @
7f91c408
...
@@ -268,10 +268,10 @@ int create_conv_indice_pair_cpu(
...
@@ -268,10 +268,10 @@ int create_conv_indice_pair_cpu(
if
(
numActIn
==
0
)
if
(
numActIn
==
0
)
return
0
;
return
0
;
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
using
Index
=
TV_DECLTYPE
(
V
);
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
constexpr
int
NDim
=
TV_DECLTYPE
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
...
@@ -308,10 +308,10 @@ int create_submconv_indice_pair_cpu(
...
@@ -308,10 +308,10 @@ int create_submconv_indice_pair_cpu(
if
(
numActIn
==
0
)
if
(
numActIn
==
0
)
return
0
;
return
0
;
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
using
Index
=
TV_DECLTYPE
(
V
);
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
constexpr
int
NDim
=
TV_DECLTYPE
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
...
...
src/spconv/indice.cu
View file @
7f91c408
...
@@ -45,10 +45,10 @@ int create_conv_indice_pair_p1_cuda(
...
@@ -45,10 +45,10 @@ int create_conv_indice_pair_p1_cuda(
if
(
numActIn
==
0
)
if
(
numActIn
==
0
)
return
0
;
return
0
;
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
constexpr
int
NDim
=
TV_DECLTYPE
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
...
@@ -57,7 +57,7 @@ int create_conv_indice_pair_p1_cuda(
...
@@ -57,7 +57,7 @@ int create_conv_indice_pair_p1_cuda(
outSpatialShape
.
end
());
outSpatialShape
.
end
());
tv
::
DispatchInt
<
max_kernel_vol_t
>
()(
tv
::
DispatchInt
<
max_kernel_vol_t
>
()(
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
constexpr
int
MaxKernelVolume
=
TV_DECLTYPE
(
I2
)
::
value
;
if
(
transpose
)
{
if
(
transpose
)
{
prepareDeConvIndicePairsKernel
<
Index
,
NDim
,
MaxKernelVolume
>
prepareDeConvIndicePairsKernel
<
Index
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
...
@@ -106,10 +106,10 @@ int create_conv_indice_pair_p2_cuda(
...
@@ -106,10 +106,10 @@ int create_conv_indice_pair_p2_cuda(
if
(
numActIn
==
0
)
if
(
numActIn
==
0
)
return
0
;
return
0
;
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
constexpr
int
NDim
=
TV_DECLTYPE
(
I
)
::
value
;
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
outSpatialShape
.
end
());
outSpatialShape
.
end
());
...
@@ -212,10 +212,10 @@ int create_submconv_indice_pair_cuda(
...
@@ -212,10 +212,10 @@ int create_submconv_indice_pair_cuda(
if
(
numActIn
==
0
)
if
(
numActIn
==
0
)
return
0
;
return
0
;
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
using
IndexGrid
=
int32_t
;
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
constexpr
int
NDim
=
TV_DECLTYPE
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
...
@@ -254,7 +254,7 @@ int create_submconv_indice_pair_cuda(
...
@@ -254,7 +254,7 @@ int create_submconv_indice_pair_cuda(
auto
stash_count
=
table
.
get_stash_count
();
auto
stash_count
=
table
.
get_stash_count
();
tv
::
DispatchInt
<
max_kernel_vol_t
>
()(
tv
::
DispatchInt
<
max_kernel_vol_t
>
()(
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
constexpr
int
MaxKernelVolume
=
TV_DECLTYPE
(
I2
)
::
value
;
getSubMIndicePairsHashKernel
<
Index
,
NDim
,
MaxKernelVolume
>
getSubMIndicePairsHashKernel
<
Index
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
...
@@ -286,8 +286,8 @@ int create_submconv_indice_pair_cuda(
...
@@ -286,8 +286,8 @@ int create_submconv_indice_pair_cuda(
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
0
],
[
&
](
auto
K0C
)
{
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
0
],
[
&
](
auto
K0C
)
{
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
1
],
[
&
](
auto
K1C
)
{
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
1
],
[
&
](
auto
K1C
)
{
constexpr
int
K0
=
decltype
(
K0C
)
::
value
;
constexpr
int
K0
=
TV_DECLTYPE
(
K0C
)
::
value
;
constexpr
int
K1
=
decltype
(
K1C
)
::
value
;
constexpr
int
K1
=
TV_DECLTYPE
(
K1C
)
::
value
;
found
=
true
;
found
=
true
;
getSubMIndicePairsKernel2
<
Index
,
IndexGrid
,
K0
,
K1
>
getSubMIndicePairsKernel2
<
Index
,
IndexGrid
,
K0
,
K1
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
...
@@ -306,9 +306,9 @@ int create_submconv_indice_pair_cuda(
...
@@ -306,9 +306,9 @@ int create_submconv_indice_pair_cuda(
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
1
],
[
&
](
auto
K1C
)
{
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
1
],
[
&
](
auto
K1C
)
{
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
tv
::
dispatch_int_noexcept
<
1
,
3
,
5
>
(
kernelSize
[
2
],
[
&
](
auto
K2C
)
{
kernelSize
[
2
],
[
&
](
auto
K2C
)
{
constexpr
int
K0
=
decltype
(
K0C
)
::
value
;
constexpr
int
K0
=
TV_DECLTYPE
(
K0C
)
::
value
;
constexpr
int
K1
=
decltype
(
K1C
)
::
value
;
constexpr
int
K1
=
TV_DECLTYPE
(
K1C
)
::
value
;
constexpr
int
K2
=
decltype
(
K2C
)
::
value
;
constexpr
int
K2
=
TV_DECLTYPE
(
K2C
)
::
value
;
found
=
true
;
found
=
true
;
getSubMIndicePairsKernel3
<
Index
,
IndexGrid
,
K0
,
K1
,
K2
>
getSubMIndicePairsKernel3
<
Index
,
IndexGrid
,
K0
,
K1
,
K2
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
...
@@ -326,7 +326,7 @@ int create_submconv_indice_pair_cuda(
...
@@ -326,7 +326,7 @@ int create_submconv_indice_pair_cuda(
if
(
!
found
)
{
if
(
!
found
)
{
tv
::
DispatchInt
<
tv
::
DispatchInt
<
max_kernel_vol_t
>
()(
ndim
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
max_kernel_vol_t
>
()(
ndim
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
constexpr
int
MaxKernelVolume
=
TV_DECLTYPE
(
I2
)
::
value
;
getSubMIndicePairsKernel
<
Index
,
IndexGrid
,
NDim
,
MaxKernelVolume
>
getSubMIndicePairsKernel
<
Index
,
IndexGrid
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
...
...
src/spconv/maxpool.cc
View file @
7f91c408
...
@@ -29,9 +29,9 @@ void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
...
@@ -29,9 +29,9 @@ void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto
dtype
=
inFeatures
.
scalar_type
();
auto
dtype
=
inFeatures
.
scalar_type
();
auto
int_dtype
=
indicesIn
.
scalar_type
();
auto
int_dtype
=
indicesIn
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
auto
outFeaturesData
=
outFeatures
.
data_ptr
<
T
>
();
auto
outFeaturesData
=
outFeatures
.
data_ptr
<
T
>
();
auto
inFeaturesData
=
inFeatures
.
data_ptr
<
T
>
();
auto
inFeaturesData
=
inFeatures
.
data_ptr
<
T
>
();
auto
indicesInData
=
indicesIn
.
data_ptr
<
Index
>
();
auto
indicesInData
=
indicesIn
.
data_ptr
<
Index
>
();
...
@@ -58,9 +58,9 @@ void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
...
@@ -58,9 +58,9 @@ void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto
dtype
=
inFeatures
.
scalar_type
();
auto
dtype
=
inFeatures
.
scalar_type
();
auto
int_dtype
=
indicesIn
.
scalar_type
();
auto
int_dtype
=
indicesIn
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
auto
outFeaturesData
=
outFeatures
.
data_ptr
<
T
>
();
auto
outFeaturesData
=
outFeatures
.
data_ptr
<
T
>
();
auto
inFeaturesData
=
inFeatures
.
data_ptr
<
T
>
();
auto
inFeaturesData
=
inFeatures
.
data_ptr
<
T
>
();
auto
doutData
=
dout
.
data_ptr
<
T
>
();
auto
doutData
=
dout
.
data_ptr
<
T
>
();
...
...
src/spconv/maxpool.cu
View file @
7f91c408
...
@@ -320,13 +320,13 @@ void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
...
@@ -320,13 +320,13 @@ void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
outFeatures
,
&
inFeatures
,
&
indicesIn
,
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
outFeatures
,
&
inFeatures
,
&
indicesIn
,
...
@@ -404,12 +404,12 @@ void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
...
@@ -404,12 +404,12 @@ void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
outFeatures
,
&
inFeatures
,
&
dout
,
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
outFeatures
,
&
inFeatures
,
&
dout
,
...
...
src/spconv/reordering.cc
View file @
7f91c408
...
@@ -26,9 +26,9 @@ void sparse_gather_cpu(torch::Tensor buffer, torch::Tensor features,
...
@@ -26,9 +26,9 @@ void sparse_gather_cpu(torch::Tensor buffer, torch::Tensor features,
auto
dtype
=
features
.
scalar_type
();
auto
dtype
=
features
.
scalar_type
();
auto
int_dtype
=
indices
.
scalar_type
();
auto
int_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
const
T
*
features_data
=
features
.
data_ptr
<
T
>
();
const
T
*
features_data
=
features
.
data_ptr
<
T
>
();
...
@@ -50,9 +50,9 @@ void sparse_scatter_add_cpu(torch::Tensor buffer, torch::Tensor outFeatures,
...
@@ -50,9 +50,9 @@ void sparse_scatter_add_cpu(torch::Tensor buffer, torch::Tensor outFeatures,
auto
int_dtype
=
indices
.
scalar_type
();
auto
int_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
const
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
const
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
T
*
features_data
=
outFeatures
.
data_ptr
<
T
>
();
T
*
features_data
=
outFeatures
.
data_ptr
<
T
>
();
...
...
src/spconv/reordering.cu
View file @
7f91c408
...
@@ -51,10 +51,10 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
...
@@ -51,10 +51,10 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
auto
dtype
=
features
.
scalar_type
();
auto
dtype
=
features
.
scalar_type
();
auto
inds_dtype
=
indices
.
scalar_type
();
auto
inds_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
...
@@ -140,10 +140,10 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
...
@@ -140,10 +140,10 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
auto
inds_dtype
=
indices
.
scalar_type
();
auto
inds_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
// important for half.
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
// important for half.
...
@@ -235,10 +235,10 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
...
@@ -235,10 +235,10 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
int
inds_stride
=
indices
.
size
(
1
);
int
inds_stride
=
indices
.
size
(
1
);
int
feature_stride
=
buffer
.
size
(
1
);
int
feature_stride
=
buffer
.
size
(
1
);
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
typename
half_vec
<
T
>::
type
;
using
vecload_type_t
=
typename
half_vec
<
T
>::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
tv
::
mp_for_each
<
kernel_block_t
>
(
tv
::
mp_for_each
<
kernel_block_t
>
(
...
@@ -308,10 +308,10 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
...
@@ -308,10 +308,10 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
int
feature_stride
=
buffer
.
size
(
1
);
int
feature_stride
=
buffer
.
size
(
1
);
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
T
=
TV_DECLTYPE
(
TValue
);
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
Index
=
TV_DECLTYPE
(
IndexValue
);
bool
notFound
=
true
;
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
1
;
// important for half.
constexpr
int
vecloadFactor
=
1
;
// important for half.
...
...
cutlass
@
c2b80ad4
Subproject commit c2b80ad4e4f8b60a65500bd04c8fecddff2ba355
mp11
@
29764aad
Subproject commit 29764aad4881fde809af6a025c12012e47a55515
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment