"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "c05ebb74b1a04376cc4f7863a66efec1457bdede"
Commit 910d6a98 authored by sangwzh's avatar sangwzh
Browse files

update atomicAdd and csr2coo.hip

parent 8f9dcabf
...@@ -169,7 +169,7 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( / ...@@ -169,7 +169,7 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( /
return Cast<dtype>::Decode(old); \ return Cast<dtype>::Decode(old); \
} }
#define DEFINE_ATOMIC_16BIT_BF(NAME, dtype) \ #define DEFINE_ATOMIC_16BIT_MAX(NAME, dtype) \
template <> \ template <> \
__device__ __forceinline__ dtype Atomic##NAME<dtype>( \ __device__ __forceinline__ dtype Atomic##NAME<dtype>( \
dtype * addr, dtype val) { \ dtype * addr, dtype val) { \
...@@ -181,12 +181,12 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( / ...@@ -181,12 +181,12 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( /
assumed = old; \ assumed = old; \
old = atomicCASshort( \ old = atomicCASshort( \
addr_as_ui, assumed, \ addr_as_ui, assumed, \
Cast<dtype>::Encode(max((double)val, (double)dtype(old)))); \ Cast<dtype>::Encode(dtype(max((float)val, (float)dtype(old))))); \
} while (assumed != old); \ } while (assumed != old); \
return Cast<dtype>::Decode(old); \ return Cast<dtype>::Decode(old); \
} }
#define DEFINE_ATOMIC_16BIT_Min(NAME, dtype) \ #define DEFINE_ATOMIC_16BIT_MIN(NAME, dtype) \
template <> \ template <> \
__device__ __forceinline__ dtype Atomic##NAME<dtype>( \ __device__ __forceinline__ dtype Atomic##NAME<dtype>( \
dtype * addr, dtype val) { \ dtype * addr, dtype val) { \
...@@ -198,24 +198,25 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( / ...@@ -198,24 +198,25 @@ static __host__ __device__ __forceinline__ unsigned short int atomicCASshort( /
assumed = old; \ assumed = old; \
old = atomicCASshort( \ old = atomicCASshort( \
addr_as_ui, assumed, \ addr_as_ui, assumed, \
Cast<dtype>::Encode(min(val, dtype(old)))); \ Cast<dtype>::Encode(dtype(min((float)val,(float)old)))); \
} while (assumed != old); \ } while (assumed != old); \
return Cast<dtype>::Decode(old); \ return Cast<dtype>::Decode(old); \
} }
#define OP(a, b) max((double)a, (double)b) #define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max) DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT(Max, half) DEFINE_ATOMIC_16BIT_MAX(Max, half)
#if BF16_ENABLED #if BF16_ENABLED
DEFINE_ATOMIC_16BIT_BF(Max, __hip_bfloat16) #define OP_BF(a, b) max_bf((float)a, (float)b)
DEFINE_ATOMIC_16BIT_MAX(Max, __hip_bfloat16)
#endif // BF16_ENABLED #endif // BF16_ENABLED
#undef OP #undef OP
#define OP(a, b) min((double)a, (double)b) #define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min) DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT(Min, half) DEFINE_ATOMIC_16BIT_MIN(Min, half)
#if BF16_ENABLED #if BF16_ENABLED
DEFINE_ATOMIC_16BIT_BF(Min, __hip_bfloat16) DEFINE_ATOMIC_16BIT_MIN(Min, __hip_bfloat16)
#endif // BF16_ENABLED #endif // BF16_ENABLED
#undef OP #undef OP
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <thrust/iterator/constant_iterator.h> #include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h> #include <thrust/iterator/transform_iterator.h>
#include <hipcub/backend/rocprim/device/device_copy.hpp>
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
...@@ -103,7 +104,7 @@ __global__ void _RepeatKernel( ...@@ -103,7 +104,7 @@ __global__ void _RepeatKernel(
} }
#if 0 #if 1
template <> template <>
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) { COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx; const auto& ctx = csr.indptr->ctx;
...@@ -126,14 +127,14 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) { ...@@ -126,14 +127,14 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max(); constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();
for (int64_t i = 0; i < csr.num_rows; i += max_copy_at_once) { for (int64_t i = 0; i < csr.num_rows; i += max_copy_at_once) {
std::size_t temp_storage_bytes = 0; std::size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceCopy::Batched( CUDA_CALL(hipcub::DeviceCopy::Batched(
nullptr, temp_storage_bytes, input_buffer + i, output_buffer + i, nullptr, temp_storage_bytes, input_buffer + i, output_buffer + i,
buffer_sizes + i, ::min(csr.num_rows - i, max_copy_at_once), buffer_sizes + i, ::min(csr.num_rows - i, max_copy_at_once),
stream)); stream));
auto temp = allocator.alloc_unique<char>(temp_storage_bytes); auto temp = allocator.alloc_unique<char>(temp_storage_bytes);
CUDA_CALL(cub::DeviceCopy::Batched( CUDA_CALL(hipcub::DeviceCopy::Batched(
temp.get(), temp_storage_bytes, input_buffer + i, output_buffer + i, temp.get(), temp_storage_bytes, input_buffer + i, output_buffer + i,
buffer_sizes + i, ::min(csr.num_rows - i, max_copy_at_once), buffer_sizes + i, ::min(csr.num_rows - i, max_copy_at_once),
stream)); stream));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment