Commit fcb4ebeb authored by qinyiqun

issue/32: extract the common parts of CUDA-like kernels to reduce redundant code

parent 5c94d4e1
#ifndef __INFINIOP_CUDA_COMMON_CUH__
#define __INFINIOP_CUDA_COMMON_CUH__
#include "../../reduce/cuda/reduce.cuh"
#include "cuda_handle.cuh"
#include "infinicore.h"
#ifdef ENABLE_SUGON_CUDA_API
#define INFINIOP_CUDA_KERNEL __launch_bounds__(512) __global__ void
#else
#define INFINIOP_CUDA_KERNEL __global__ void
#endif
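// e.g. "INFINIOP_CUDA_KERNEL myKernel(...)" (myKernel is a placeholder name)
// declares a __global__ kernel, with a 512-thread launch bound on Sugon builds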
// Possible maximum number of threads per block for CUDA architectures
// Used for picking the correct kernel launch configuration
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
namespace device::cuda {
cudnnDataType_t getCudnnDtype(infiniDtype_t dt);
// Return the memory offset into the original tensor, given the flat index of the broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i] * target_strides[i];
flat_index %= broadcasted_strides[i];
}
return res;
}
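// Worked example: broadcasting shape [1, 3] (strides [0, 1]) to shape [2, 3]
// (contiguous strides [3, 1]); flat_index 4 addresses element (1, 1) of the
// broadcasted tensor:
//   i = 0: 4 / 3 = 1, 1 * 0 = 0, flat_index = 4 % 3 = 1
//   i = 1: 1 / 1 = 1, 1 * 1 = 1  -> res = 1 (the single row is reused)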
// Return the memory offset of an element in a tensor, given its flat index
__forceinline__ __device__ __host__ size_t
indexToOffset(
size_t flat_index,
size_t ndim,
const size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}
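// Worked example: shape [2, 3] with transposed strides [1, 2]; flat_index 4
// is element (1, 1):
//   i = 1: 4 % 3 = 1, 1 * 2 = 2, flat_index = 4 / 3 = 1
//   i = 0: 1 % 2 = 1, 1 * 1 = 1  -> res = 3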
} // namespace device::cuda
#endif // __INFINIOP_CUDA_COMMON_CUH__
#ifdef ENABLE_SUGON_CUDA_API
#define INFINIOP_CUDA_KERNEL __launch_bounds__(512) __global__ void
#else
#define INFINIOP_CUDA_KERNEL __global__ void
#endif
// Possible maximum number of threads per block for CUDA architectures
// Used for picking the correct kernel launch configuration
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
// Return the memory offset into the original tensor, given the flat index of the broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i] * target_strides[i];
flat_index %= broadcasted_strides[i];
}
return res;
}
// Return the memory offset of an element in a tensor, given its flat index
__forceinline__ __device__ __host__ size_t
indexToOffset(
size_t flat_index,
size_t ndim,
const size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
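// Overload set so generic kernel code can call exp_(x) and let the compiler
// pick the precision-matched variant (expf / exp / expl / hexp) for each type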
__forceinline__ __device__ float
exp_(const float val) {
return expf(val);
}
__forceinline__ __device__ long double
exp_(const long double val) {
return expl(val);
}
__forceinline__ __device__ double
exp_(const double val) {
return exp(val);
}
__forceinline__ __device__ __half
exp_(const __half x) {
return hexp(x);
}
#endif
@@ -10,9 +10,6 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
#define INFINIOP_MUSA_KERNEL __global__ void
#define MUSA_BLOCK_SIZE_1024 1024
namespace device::musa {
class Handle::Internal {
......
#include "causal_softmax_cuda.cuh"
#include "../../../devices/cuda/cuda_common.cuh"
#include "causal_softmax_cuda.cuh"
#include "causal_softmax_kernel.cuh"
namespace op::causal_softmax::cuda {
......
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
#include "../../../devices/cuda/cuda_common.cuh"
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(
@@ -31,7 +32,11 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
// 2 | * * * ... * * * |
// height: 3 col_id->
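// keep the element when threadIdx.x <= blockIdx.x + (width - height),
// i.e. the causal mask with the diagonal anchored at the bottom-right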
if (width + blockIdx.x >= threadIdx.x + height) {
#ifdef ENABLE_CUDA_API
y[col] = exp_(x[col] - max_);
#else
y[col] = exp(x[col] - max_);
#endif
} else {
y[col] = Tdata(0);
}
......
#ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__
#include "../../../devices/cuda/cuda_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
INFINIOP_CUDA_KERNEL rmsnormBlock(
......
#ifndef __RMS_NORM_MUSA_KERNEL_H__
#define __RMS_NORM_MUSA_KERNEL_H__
#include "../../../reduce/musa/reduce.h"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
INFINIOP_MUSA_KERNEL rmsnormBlock(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
// Each block handles one row of contiguous data of length dim
// Each thread processes elements of the row with a stride of BLOCK_SIZE
auto y_ptr = y + blockIdx.x * stride_y;
auto x_ptr = x + blockIdx.x * stride_x;
auto w_ptr = w;
// Block-reduce sum of x^2
Tcompute ss = op::common_musa::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(x_ptr, dim);
// Thread 0 computes RMS = 1 / sqrt(ss / dim + epsilon) and stores it in shared memory
__shared__ Tcompute rms;
if (threadIdx.x == 0) {
rms = Tdata(rsqrtf(ss / Tcompute(dim) + epsilon));
}
__syncthreads();
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
y_ptr[i] = Tdata(Tcompute(x_ptr[i]) * Tcompute(w_ptr[i]) * rms);
}
}
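// Hypothetical launch sketch, one block per row (types and values are for
// illustration only; BLOCK_SIZE must equal the block dimension):
//   rmsnormBlock<1024, __half, __half, float>
//       <<<batch_size, 1024, 0, stream>>>(y, stride_y, x, stride_x, w, dim, epsilon);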
#endif
#ifndef __RMS_NORM_MUSA_H__
#define __RMS_NORM_MUSA_H__
#ifndef __RMS_NORM_MUSA_CUH__
#define __RMS_NORM_MUSA_CUH__
#include "../rms_norm.h"
......
#include "../../../devices/musa/common_musa.h"
#include "rms_norm_kernel.h"
#include "rms_norm_musa.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_musa.cuh"
namespace op::rms_norm::musa {
@@ -87,8 +87,8 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
@@ -12,7 +12,7 @@
#include "ascend/rms_norm_aclnn.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.h"
#include "musa/rms_norm_musa.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
......
#ifndef __INFINIOP_REDUCE_MUSA_H__
#define __INFINIOP_REDUCE_MUSA_H__
#include <cub/block/block_reduce.cuh>
namespace op::common_musa::reduce_op {
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t count) {
Tcompute ss = 0;
// Each thread computes its partial sum
for (size_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
ss += Tcompute(data_ptr[i] * data_ptr[i]);
}
// Use CUB block-level reduction
using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage temp_storage;
return BlockReduce(temp_storage).Sum(ss);
}
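// Note: cub::BlockReduce returns the block aggregate only to thread 0; callers
// such as rmsnormBlock must broadcast it through shared memory before other
// threads read it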
} // namespace op::common_musa::reduce_op
#endif