Unverified Commit 4f8afafd authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #155 from InfiniTensor/issue/32

issue/32: 实现摩尔线程rms_norm算子
parents 2391ec99 fcb4ebeb
#ifndef __INFINIOP_CUDA_COMMON_CUH__
#define __INFINIOP_CUDA_COMMON_CUH__
#include "../../reduce/cuda/reduce.cuh"
#include "cuda_handle.cuh"
#include "infinicore.h"
#ifdef ENABLE_SUGON_CUDA_API
#define INFINIOP_CUDA_KERNEL __launch_bounds__(512) __global__ void
#else
#define INFINIOP_CUDA_KERNEL __global__ void
#endif
// Posible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
namespace device::cuda {
cudnnDataType_t getCudnnDtype(infiniDtype_t dt);
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i] * target_strides[i];
flat_index %= broadcasted_strides[i];
}
return res;
}
// get the memory offset of the given element in a tensor given its flat index
__forceinline__ __device__ __host__ size_t
indexToOffset(
size_t flat_index,
size_t ndim,
const size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}
} // namespace device::cuda
#endif // __INFINIOP_CUDA_COMMON_CUH__
#ifdef ENABLE_SUGON_CUDA_API
#define INFINIOP_CUDA_KERNEL __launch_bounds__(512) __global__ void
#else
#define INFINIOP_CUDA_KERNEL __global__ void
#endif
// Posible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i] * target_strides[i];
flat_index %= broadcasted_strides[i];
}
return res;
}
// get the memory offset of the given element in a tensor given its flat index
__forceinline__ __device__ __host__ size_t
indexToOffset(
size_t flat_index,
size_t ndim,
const size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i]) * strides[i];
flat_index /= shape[i];
}
return res;
}
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
__forceinline__ __device__ float
exp_(const float val) {
return expf(val);
}
__forceinline__ __device__ long double
exp_(const long double val) {
return expl(val);
}
__forceinline__ __device__ double
exp_(const double val) {
return exp(val);
}
__forceinline__ __device__ __half
exp_(const __half x) {
return hexp(x);
}
#endif
......@@ -16,12 +16,27 @@ class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;
int _warp_size,
_max_threads_per_block,
_block_size[3],
_grid_size[3];
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
Internal(int);
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
int warpSize() const;
int maxThreadsPerBlock() const;
int blockSizeX() const;
int blockSizeY() const;
int blockSizeZ() const;
int gridSizeX() const;
int gridSizeY() const;
int gridSizeZ() const;
};
} // namespace device::musa
......@@ -3,7 +3,7 @@
namespace device::musa {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>()) {}
_internal(std::make_shared<Handle::Internal>(device_id)) {}
Handle::Handle(int device_id) : Handle(INFINI_DEVICE_MOORE, device_id) {}
......@@ -11,6 +11,19 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
Handle::Internal::Internal(int device_id) {
musaDeviceProp prop;
musaGetDeviceProperties(&prop, device_id);
_warp_size = prop.warpSize;
_max_threads_per_block = prop.maxThreadsPerBlock;
_block_size[0] = prop.maxThreadsDim[0];
_block_size[1] = prop.maxThreadsDim[1];
_block_size[2] = prop.maxThreadsDim[2];
_grid_size[0] = prop.maxGridSize[0];
_grid_size[1] = prop.maxGridSize[1];
_grid_size[2] = prop.maxGridSize[2];
}
infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
std::unique_ptr<mublasHandle_t> handle;
auto opt_handle = mublas_handles.pop();
......@@ -40,6 +53,15 @@ infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::
return INFINI_STATUS_SUCCESS;
}
int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(INFINI_DEVICE_MOORE, device_id);
return INFINI_STATUS_SUCCESS;
......
#include "causal_softmax_cuda.cuh"
#include "../../../devices/cuda/cuda_common.cuh"
#include "causal_softmax_cuda.cuh"
#include "causal_softmax_kernel.cuh"
namespace op::causal_softmax::cuda {
......
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
#include "../../../devices/cuda/cuda_common.cuh"
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(
......@@ -31,7 +32,11 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
// 2 | * * * ... * * * |
// height: 3 col_id->
if (width + blockIdx.x >= threadIdx.x + height) {
#ifdef ENABLE_CUDA_API
y[col] = exp_(x[col] - max_);
#else
y[col] = exp(x[col] - max_);
#endif
} else {
y[col] = Tdata(0);
}
......
#ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__
#include "../../../devices/cuda/cuda_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
INFINIOP_CUDA_KERNEL rmsnormBlock(
......
#ifndef __RMS_NORM_MUSA_CUH__
#define __RMS_NORM_MUSA_CUH__
#include "../rms_norm.h"
DESCRIPTOR(musa)
#endif
#include "../../../devices/musa/common_musa.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_musa.cuh"
namespace op::rms_norm::musa {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::musa::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::musa
......@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#endif
......@@ -56,10 +59,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
return macaCreateRMSNormDescriptor((MacaHandle_t)handle, (RMSNormMacaDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaCreateRMSNormDescriptor((MusaHandle_t)handle, (RMSNormMusaDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa)
#endif
}
......@@ -98,10 +99,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetRMSNormWorkspaceSize((RMSNormMusaDescriptor_t)desc, size);
}
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa)
#endif
}
......@@ -141,10 +140,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
return macaRMSNorm((RMSNormMacaDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaRMSNorm((RMSNormMusaDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa)
#endif
}
......@@ -183,10 +180,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t)desc);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaDestroyRMSNormDescriptor((RMSNormMusaDescriptor_t)desc);
}
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, musa)
#endif
}
......
......@@ -4,6 +4,8 @@
#include "bang/infinirt_bang.h"
#include "cpu/infinirt_cpu.h"
#include "cuda/infinirt_cuda.cuh"
#include "maca/infinirt_maca.h"
#include "musa/infinirt_musa.h"
thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU;
thread_local int CURRENT_DEVICE_ID = 0;
......@@ -58,6 +60,12 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
case INFINI_DEVICE_ASCEND: \
_status = infinirt::ascend::API PARAMS; \
break; \
case INFINI_DEVICE_METAX: \
_status = infinirt::maca::API PARAMS; \
break; \
case INFINI_DEVICE_MOORE: \
_status = infinirt::musa::API PARAMS; \
break; \
default: \
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
} \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment