Commit 4d745cf9 authored by zhangyue

insert kunlun kernel to framework

parent 777fd6af
#ifndef __INFINIOP_KUNLUN_COMMON_H__
#define __INFINIOP_KUNLUN_COMMON_H__
// This header file will only be included by .xpu files
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// Get bit mask for the vload_lm_* intrinsics:
// bits 0..i are set to 1, all higher bits are 0
static inline __device__ int lowerBitMask(int i) {
    return (1 << (i + 1)) - 1;
}
#endif
\ No newline at end of file
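As a quick sanity check of the mask arithmetic (an editorial sketch, not part of the commit): lowerBitMask(i) yields a mask whose bits 0..i are set, which the masked vload_lm_*_mz loads below use to pick up a tail of fewer than 16 floats. A minimal host-side check in standard C++:

#include <cassert>

// Host-side restatement of the device helper, for illustration only.
static int lowerBitMaskRef(int i) {
    return (1 << (i + 1)) - 1;
}

int main() {
    assert(lowerBitMaskRef(0) == 0x1);      // one active lane
    assert(lowerBitMaskRef(2) == 0x7);      // three active lanes
    assert(lowerBitMaskRef(15) == 0xFFFF);  // all 16 lanes of a float32x16_t
    // A tail of `remain` elements uses lowerBitMask(remain - 1);
    // remain == 0 gives mask 0, i.e. no active lanes.
    assert(lowerBitMaskRef(-1) == 0);
    return 0;
}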
#ifndef __RMS_NORM_KUNLUN_KERNEL_H__
#define __RMS_NORM_KUNLUN_KERNEL_H__

#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"

// Element wise mul used in x * w
static inline __device__ void elementMul(float *x, float *w, float *y, int count, float rms) {
@@ -82,7 +47,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
    if (cid >= ncores) {
        return;
    }
    // Divide m rows into all clusters equally
    // if m % cluster_num() != 0, clusters with cluster_id < m % cluster_num() do 1 row more
    // [m_start, m_end) is the range of the m dim handled by the current cluster
@@ -129,7 +94,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
        GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
        // do reduce
        float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
        atomic_add(&sm_output[curr_m - m_start], ss);
    }
    mfence();
@@ -161,4 +126,8 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
    }
}

void rms_norm_f32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
    rms_norm<<<8, 32, stream>>>((float *)y, stride_y, (float *)x, stride_x, (float *)w, m, n, epsilon);
}

#endif
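The comment about dividing the m rows across clusters describes a balanced split. A host-side sketch of that split, assuming the usual arithmetic (the exact index computation sits in the elided hunk and is not shown in this diff, so treat the helper below as a hypothetical reconstruction):

#include <cstdio>

// Hypothetical reconstruction: every cluster gets m / ncluster rows,
// and clusters with cid < m % ncluster handle one extra row.
static void clusterRowRange(int m, int ncluster, int cid, int *m_start, int *m_end) {
    int base = m / ncluster;
    int extra = m % ncluster;
    *m_start = cid * base + (cid < extra ? cid : extra);
    *m_end = *m_start + base + (cid < extra ? 1 : 0);
}

int main() {
    // e.g. m = 10 rows over 8 clusters: clusters 0 and 1 take 2 rows, the others 1
    for (int cid = 0; cid < 8; ++cid) {
        int s, e;
        clusterRowRange(10, 8, cid, &s, &e);
        std::printf("cluster %d -> [%d, %d)\n", cid, s, e);
    }
    return 0;
}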
#include "rms_norm_kunlun.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include <memory>
#include <stdint.h>
void rms_norm_f32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream);
namespace op::rms_norm::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
    // Check the shape first so a non-2D tensor never indexes strides[1]
    if (info.ndim() != 2) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
info,
0,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t launchKernel(
int m, int n,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
kunlunStream_t stream) {
if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
rms_norm_f32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
int n = static_cast<int>(_info.dim());
int m = static_cast<int>(_info.shape[0]);
    // Propagate an unsupported-dtype error instead of silently reporting success
    CHECK_STATUS(launchKernel(m, n, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, reinterpret_cast<kunlunStream_t>(stream)));
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::kunlun
\ No newline at end of file
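For reference, what this Kunlun path computes per row is standard RMS normalization. The epsilon placement lives in the elided kernel hunks, so the conventional mean-of-squares form is assumed in this scalar sketch (names are illustrative, not from the commit):

#include <cmath>

// Illustrative scalar reference for one row of length n, mirroring the
// stride_x / stride_y row strides passed to rms_norm_f32.
static void rmsNormRowRef(float *y, long stride_y, const float *x, long stride_x,
                          const float *w, int row, int n, float epsilon) {
    const float *xr = x + row * stride_x;
    float *yr = y + row * stride_y;
    float ss = 0.0f;
    for (int j = 0; j < n; ++j) {
        ss += xr[j] * xr[j]; // sum of squares (sumSquaredF32 on the device)
    }
    float rms = 1.0f / std::sqrt(ss / n + epsilon); // assumed epsilon placement
    for (int j = 0; j < n; ++j) {
        yr[j] = xr[j] * w[j] * rms; // element-wise x * w scaled by rms (elementMul on the device)
    }
}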
#ifndef __RMS_NORM_KUNLUN_H__
#define __RMS_NORM_KUNLUN_H__
#include "../rms_norm.h"
DESCRIPTOR(kunlun)
#endif
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/rms_norm_cuda.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#endif

__C infiniStatus_t infiniopCreateRMSNormDescriptor(
    infiniopHandle_t handle,
@@ -34,6 +37,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_CUDA_API
    CREATE(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
    CREATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
    case DevCambriconMlu: {
        return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_CUDA_API
    GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
    GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
    case DevCambriconMlu: {
        return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
@@ -123,6 +132,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_CUDA_API
    CALCULATE(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
    CALCULATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
    case DevCambriconMlu: {
        return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
@@ -170,6 +182,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_CUDA_API
    DESTROY(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
    DESTROY(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
    case DevCambriconMlu: {
        return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
...
#ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__
#include "../../devices/kunlun/kunlun_common.h"
namespace op::common_kunlun::reduce_op {
// Sum-of-squares reduction using 16-lane float SIMD instructions
// data_ptr points into local memory (LM)
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
    __local__ float acc_buf[16];
    int remain = count % 16;
    int offset_last = count - remain;
    int mask = lowerBitMask(remain - 1);
    // Load the tail (remain < 16 floats) with a masked load
    float32x16_t v_last = vload_lm_float32x16_mz((data_ptr + offset_last), mask);
    // Square the tail: v_last * v_last
    v_last = vvmul_float32x16(v_last, v_last);
    // Accumulate the body 16 floats at a time
    for (int i = 0; i < offset_last; i += 16) {
        float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
        // Square: v_0 * v_0
        v_0 = vvmul_float32x16(v_0, v_0);
        // Add to the running accumulator
        v_last = vvadd_float32x16(v_last, v_0);
    }
    // Spill the 16 partial sums to LM and reduce them to a scalar
    vstore_lm_float32x16_mz(acc_buf, v_last);
    mfence();
    float res = 0.0f;
    for (int i = 0; i < 16; ++i) {
        res += acc_buf[i];
    }
    return res;
}
} // namespace op::common_kunlun::reduce_op
#endif
\ No newline at end of file
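A host-side emulation of the blocked reduction above (editorial, standard C++ only): accumulate squares into 16 partial sums for the full groups, fold the tail into the low lanes as the masked load does, then sum the partials. Up to floating-point reordering it matches a plain scalar loop.

#include <cstdio>

// Illustrative host emulation of the 16-lane blocked sum of squares.
static float sumSquaredRef(const float *data, int count) {
    float acc[16] = {0.0f};
    int remain = count % 16;
    int offset_last = count - remain;
    for (int i = 0; i < offset_last; i += 16) { // body: full groups of 16
        for (int lane = 0; lane < 16; ++lane) {
            acc[lane] += data[i + lane] * data[i + lane];
        }
    }
    for (int lane = 0; lane < remain; ++lane) { // tail: the masked load's active lanes
        acc[lane] += data[offset_last + lane] * data[offset_last + lane];
    }
    float res = 0.0f;
    for (int lane = 0; lane < 16; ++lane) { // fold the 16 partial sums
        res += acc[lane];
    }
    return res;
}

int main() {
    float xs[37];
    for (int i = 0; i < 37; ++i) {
        xs[i] = 0.5f * i;
    }
    std::printf("sum of squares = %f\n", sumSquaredRef(xs, 37));
    return 0;
}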