Commit 10572e55 authored by zhuyue's avatar zhuyue
Browse files

Issue/654 - Update CUB API usage for CUDA 12.9+ compatibility

parent d18b77a0
......@@ -81,7 +81,7 @@ __device__ void blockLayernormKernel(T *output, T const *input, T const *weight,
}
__shared__ float sigma2;
float sigma2_block = BlockReduce(temp_storage).Reduce(sigma2_partial, cub::Sum());
float sigma2_block = BlockReduce(temp_storage).Sum(sigma2_partial);
if (threadIdx.x == 0) {
float sigma_tmp = sqrt(sigma2_block * __fdividef(1.0F, dimsize) + eps);
sigma2 = __fdividef(1.0F, sigma_tmp);
......
......@@ -17,7 +17,11 @@ __device__ void blockLPNormKernel(
local_max = max(local_max, fabsf((float)input[tid + ind * stride]));
}
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
global_max = max_block;
}
......@@ -30,7 +34,7 @@ __device__ void blockLPNormKernel(
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
float p_block = BlockReduce(temp_storage).Sum(p_partial);
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
p_total = powf(p_block, 1.0f / p);
}
......@@ -69,7 +73,11 @@ __device__ void blockLPNormStridesKernel(
local_max = max(local_max, fabsf((float)input[ind_i + ind]));
}
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
global_max = max_block;
}
......@@ -82,7 +90,7 @@ __device__ void blockLPNormStridesKernel(
}
__shared__ float p_total;
float p_block = BlockReduce(temp_storage).Reduce(p_partial, cub::Sum());
float p_block = BlockReduce(temp_storage).Sum(p_partial);
if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory
p_total = powf(p_block, 1.0f / p);
}
......
......@@ -40,9 +40,9 @@ __C infiniStatus_t infiniopCreateTanhDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -71,9 +71,9 @@ __C infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, s
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax);
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -109,9 +109,9 @@ __C infiniStatus_t infiniopTanh(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -142,9 +142,9 @@ infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
// #ifdef ENABLE_METAX_API
// DELETE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// DELETE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment