Commit 9a9f0982 authored by zhuyue's avatar zhuyue
Browse files

Issue/654 - Update CUB API usage for CUDA 12.9+ compatibility

parent 874cc65b
......@@ -81,7 +81,7 @@ __device__ void blockLayernormKernel(T *output, T const *input, T const *weight,
}
__shared__ float sigma2;
float sigma2_block = BlockReduce(temp_storage).Reduce(sigma2_partial, cub::Sum());
float sigma2_block = BlockReduce(temp_storage).Sum(sigma2_partial);
if (threadIdx.x == 0) {
float sigma_tmp = sqrt(sigma2_block * __fdividef(1.0F, dimsize) + eps);
sigma2 = __fdividef(1.0F, sigma_tmp);
......
......@@ -17,7 +17,11 @@ __device__ void blockLPNormKernel(
local_max = max(local_max, fabsf((float)input[tid + ind * stride]));
}
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
global_max = max_block;
}
......@@ -30,7 +34,7 @@ __device__ void blockLPNormKernel(
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
float p_block = BlockReduce(temp_storage).Sum(p_partial);
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
p_total = powf(p_block, 1.0f / p);
}
......@@ -69,7 +73,11 @@ __device__ void blockLPNormStridesKernel(
local_max = max(local_max, fabsf((float)input[ind_i + ind]));
}
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
global_max = max_block;
}
......@@ -82,7 +90,7 @@ __device__ void blockLPNormStridesKernel(
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
float p_block = BlockReduce(temp_storage).Sum(p_partial);
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
p_total = powf(p_block, 1.0f / p);
}
......
......@@ -2,8 +2,6 @@
#define __TANH_CUDA_H__
#include <cmath>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::tanh::cuda {
typedef struct TanhOp {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment