Commit b4b2adf7 authored by zhangyue's avatar zhangyue
Browse files

issue/111: rename kernel function and add mask comment

parent 98a72ab3
......@@ -7,7 +7,9 @@
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// Get mask for vload_lm_ func
// Get mask for kunlun xpu 512bit register calculation
// if data is not enough to 512bit, padding zero and use
// mask to identify real data
// 0 - i bit 1, others 0
inline __device__ float lowerBitMask(int i) {
return (1 << (i + 1)) - 1;
......
......@@ -5,7 +5,7 @@
#include "../../../reduce/kunlun/reduce_kunlun.h"
// Element wise mul used in x * w
static inline __device__ void element_mul(float *x, float *w, float *y, int count, float rms) {
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder
......@@ -29,7 +29,7 @@ static inline __device__ void element_mul(float *x, float *w, float *y, int coun
// RmsNorm main kernel func
// kunlun2 has 8 cluster and 64 core
// Call it by rmsnorm<<<8, 32, stream>>>()
__global__ void rms_norm_f32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
__global__ void rmsNormKernelF32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster
int ncores = core_num();
// get cid of current core
......@@ -103,7 +103,7 @@ __global__ void rms_norm_f32(float *y, long stride_y, const float *x, long strid
float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon);
element_mul(x_local, w_local, y_local, nn, rms);
elementwiseMulRms(x_local, w_local, y_local, nn, rms);
mfence();
auto y_ptr = y + m * stride_y + n_start;
......@@ -117,7 +117,7 @@ __global__ void rms_norm_f32(float *y, long stride_y, const float *x, long strid
}
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rms_norm_f32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
rmsNormKernelF32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment