Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
b4b2adf7
Commit
b4b2adf7
authored
Apr 10, 2025
by
zhangyue
Browse files
issue/111: rename kernel function and add mask comment
parent
98a72ab3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
src/infiniop/devices/kunlun/kunlun_common.h
src/infiniop/devices/kunlun/kunlun_common.h
+3
-1
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
+4
-4
No files found.
src/infiniop/devices/kunlun/kunlun_common.h
View file @
b4b2adf7
...
...
@@ -7,7 +7,9 @@
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// Get mask for vload_lm_ func
// Get the mask for Kunlun XPU 512-bit register calculations.
// If the data does not fill 512 bits, it is padded with zeros and
// the mask is used to identify the real data.
// Bits 0 through i are set to 1; all other bits are 0.
inline
__device__
float
lowerBitMask
(
int
i
)
{
return
(
1
<<
(
i
+
1
))
-
1
;
...
...
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
View file @
b4b2adf7
...
...
@@ -5,7 +5,7 @@
#include "../../../reduce/kunlun/reduce_kunlun.h"
// Element-wise multiplication used for x * w (scaled by rms)
static inline __device__ void element
_mul
(float *x, float *w, float *y, int count, float rms) {
static inline __device__ void element
wiseMulRms
(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder
...
...
@@ -29,7 +29,7 @@ static inline __device__ void element_mul(float *x, float *w, float *y, int coun
// Main RMSNorm kernel function.
// kunlun2 has 8 clusters and 64 cores.
// Launch it with the configuration <<<8, 32, stream>>>.
__global__ void rms
_n
orm
_f
32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
__global__ void rms
N
orm
KernelF
32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster
int ncores = core_num();
// get cid of current core
...
...
@@ -103,7 +103,7 @@ __global__ void rms_norm_f32(float *y, long stride_y, const float *x, long strid
float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon);
element
_mul
(x_local, w_local, y_local, nn, rms);
element
wiseMulRms
(x_local, w_local, y_local, nn, rms);
mfence();
auto y_ptr = y + m * stride_y + n_start;
...
...
@@ -117,7 +117,7 @@ __global__ void rms_norm_f32(float *y, long stride_y, const float *x, long strid
}
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rms
_n
orm
_f
32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
rms
N
orm
KernelF
32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment