Commit 42d92ee8 authored by Yan Yan's avatar Yan Yan
Browse files

fix #226: workaround for cuda 9.0/9.1

parent abf0acf3
if(WIN32)
add_library(cuhash SHARED hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp)
else()
add_library(cuhash STATIC hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp)
add_library(cuhash SHARED hash_functions.cu hash_table.cpp hash_table.cu hash_functions.cpp)
endif()
target_include_directories(cuhash PRIVATE ${ALL_INCLUDE} )
set_property(TARGET cuhash PROPERTY CUDA_STANDARD 14)
......
......@@ -46,8 +46,5 @@ cublasStatus_t cublasTgemm(cublasHandle_t handle, cublasOperation_t transa,
beta, C, ldc);
}
template <> inline __half constant_scalar(float data) {
return __float2half(data);
}
} // namespace spconv
\ No newline at end of file
......@@ -29,12 +29,17 @@ namespace spconv {
using float_types_t = tv::mp_list<float, double, at::Half>;
using int_types_t = tv::mp_list<int32_t, int64_t>;
template <typename T>
using half_vec_t =
std::conditional_t<std::is_same<T, at::Half>::value, int4, int4>;
struct half_vec{
using type = typename std::conditional_t<std::is_same<T, at::Half>::value, int4, int4>;
};
template <typename T>
using half_vec_sadd_t =
std::conditional_t<std::is_same<T, at::Half>::value, int4, int4>;
struct half_vec_sadd{
using type = typename std::conditional_t<std::is_same<T, at::Half>::value, int4, int4>;
};
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
......@@ -47,7 +52,7 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
auto inds_dtype = indices.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
using vecload_type_t = half_vec_t<T>;
using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue);
bool notFound = true;
......@@ -136,7 +141,7 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
using vecload_type_t = half_vec_sadd_t<T>;
using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue);
bool notFound = true;
......@@ -231,7 +236,7 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
int feature_stride = buffer.size(1);
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
using vecload_type_t = half_vec_t<T>;
using vecload_type_t = typename half_vec<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue);
bool notFound = true;
......@@ -304,7 +309,7 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue);
using vecload_type_t = half_vec_sadd_t<T>;
using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue);
bool notFound = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment