Commit 650f3057 authored by zhangyue
Browse files

issue/111: 函数名称和参数类型的修改,不涉及对外接口

parent 1c762900
...@@ -9,18 +9,21 @@ ...@@ -9,18 +9,21 @@
// Get mask for vload_lm_ func // Get mask for vload_lm_ func
// 0 - i bit 1, others 0 // 0 - i bit 1, others 0
static inline __device__ float lowerBitMask(int i) { inline __device__ float lowerBitMask(int i) {
return (1 << (i + 1)) - 1; return (1 << (i + 1)) - 1;
} }
// Atomic add for reduce // Atomic add for reduce
static inline __device__ void atomic_add(__shared_ptr__ float *ptr, float value) { inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
int fail = 1; int success = 1;
while (fail) { while (success) {
// SM2REG read 32bit data to register
float a = SM2REG_atomic(ptr); float a = SM2REG_atomic(ptr);
a = a + value; a = a + value;
fail = REG2SM_atomic(ptr, a); success = REG2SM_atomic(ptr, a);
} }
} }
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif #endif
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "../../../reduce/kunlun/reduce_kunlun.h" #include "../../../reduce/kunlun/reduce_kunlun.h"
// Element wise mul used in x * w // Element wise mul used in x * w
static inline __device__ void elementMul(float *x, float *w, float *y, int count, float rms) { static inline __device__ void element_mul(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16; int remain = count % 16;
int offset_last = count - remain; int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder // y[i] = w[i] * x[i] * rms for remainder
...@@ -29,7 +29,7 @@ static inline __device__ void elementMul(float *x, float *w, float *y, int count ...@@ -29,7 +29,7 @@ static inline __device__ void elementMul(float *x, float *w, float *y, int count
// RmsNorm main kernel func // RmsNorm main kernel func
// kunlun2 has 8 cluster and 64 core // kunlun2 has 8 cluster and 64 core
// Call it by rmsnorm<<<8, 32, stream>>>() // Call it by rmsnorm<<<8, 32, stream>>>()
__global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float *w, int m, int n, float epsilon) { __global__ void rms_norm_f32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster // ncores in a cluster
int ncores = core_num(); int ncores = core_num();
// get cid of current core // get cid of current core
...@@ -85,7 +85,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float ...@@ -85,7 +85,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
// do reduce // do reduce
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn); float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
atomic_add(&sm_output[curr_m - m_start], ss); atomicAddF32(&sm_output[curr_m - m_start], ss);
} }
mfence(); mfence();
sync_cluster(); sync_cluster();
...@@ -103,7 +103,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float ...@@ -103,7 +103,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
float ss = SM2REG_atomic(sm_output + m - m_start); float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon); float rms = 1.0f / sqrt(ss / n + epsilon);
elementMul(x_local, w_local, y_local, nn, rms); element_mul(x_local, w_local, y_local, nn, rms);
mfence(); mfence();
auto y_ptr = y + m * stride_y + n_start; auto y_ptr = y + m * stride_y + n_start;
...@@ -116,8 +116,8 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float ...@@ -116,8 +116,8 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
} }
} }
void rms_norm_f32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) { void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rms_norm<<<8, 32, stream>>>((float *)y, stride_y, (float *)x, stride_x, (float *)w, m, n, epsilon); rms_norm_f32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
} }
#endif #endif
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <memory> #include <memory>
#include <stdint.h> #include <stdint.h>
void rms_norm_f32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream); void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream);
namespace op::rms_norm::kunlun { namespace op::rms_norm::kunlun {
...@@ -53,7 +53,7 @@ infiniStatus_t launchKernel( ...@@ -53,7 +53,7 @@ infiniStatus_t launchKernel(
kunlunStream_t stream) { kunlunStream_t stream) {
if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
rms_norm_f32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream); rmsNormF32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment