Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
42d92ee8
Commit
42d92ee8
authored
Sep 27, 2020
by
Yan Yan
Browse files
fix #226: workaround for cuda 9.0/9.1
parent
abf0acf3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
12 deletions
+14
-12
src/cuhash/CMakeLists.txt
src/cuhash/CMakeLists.txt
+1
-1
src/spconv/cublas_gemm.cc
src/spconv/cublas_gemm.cc
+0
-3
src/spconv/reordering.cu
src/spconv/reordering.cu
+13
-8
No files found.
src/cuhash/CMakeLists.txt
View file @
42d92ee8
if
(
WIN32
)
add_library
(
cuhash SHARED hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp
)
else
()
add_library
(
cuhash S
TATIC
hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp
)
add_library
(
cuhash S
HARED
hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp
)
endif
()
target_include_directories
(
cuhash PRIVATE
${
ALL_INCLUDE
}
)
set_property
(
TARGET cuhash PROPERTY CUDA_STANDARD 14
)
...
...
src/spconv/cublas_gemm.cc
View file @
42d92ee8
...
...
@@ -46,8 +46,5 @@ cublasStatus_t cublasTgemm(cublasHandle_t handle, cublasOperation_t transa,
beta
,
C
,
ldc
);
}
template
<
>
inline
__half
constant_scalar
(
float
data
)
{
return
__float2half
(
data
);
}
}
// namespace spconv
\ No newline at end of file
src/spconv/reordering.cu
View file @
42d92ee8
...
...
@@ -29,12 +29,17 @@ namespace spconv {
using
float_types_t
=
tv
::
mp_list
<
float
,
double
,
at
::
Half
>
;
using
int_types_t
=
tv
::
mp_list
<
int32_t
,
int64_t
>
;
template
<
typename
T
>
using
half_vec_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int4
,
int4
>
;
struct
half_vec
{
using
type
=
typename
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int4
,
int4
>
;
};
template
<
typename
T
>
using
half_vec_sadd_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int4
,
int4
>
;
struct
half_vec_sadd
{
using
type
=
typename
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int4
,
int4
>
;
};
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
void
sparse_gather_cuda
(
torch
::
Tensor
buffer
,
torch
::
Tensor
features
,
...
...
@@ -47,7 +52,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
auto
inds_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
half_vec_t
<
T
>
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
...
...
@@ -136,7 +141,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
half_vec_sadd
_t
<
T
>
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>
::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
...
...
@@ -231,7 +236,7 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
int
feature_stride
=
buffer
.
size
(
1
);
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
half_vec
_t
<
T
>
;
using
vecload_type_t
=
typename
half_vec
<
T
>
::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
...
...
@@ -304,7 +309,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
half_vec_sadd
_t
<
T
>
;
using
vecload_type_t
=
typename
half_vec_sadd
<
T
>
::
type
;
tv
::
DispatchTorch
<
int_types_t
>
()(
inds_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment