Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
777fd6af
You need to sign in or sign up before continuing.
Commit
777fd6af
authored
Apr 01, 2025
by
zhangyue
Browse files
issue/111:添加rmsnorm以及算子编译流程
parent
65df17f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
216 additions
and
0 deletions
+216
-0
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
+164
-0
xmake/kunlun.lua
xmake/kunlun.lua
+52
-0
No files found.
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
0 → 100644
View file @
777fd6af
#ifndef __RMS_NORM_KUNLUN_KERNEL_H__
#define __RMS_NORM_KUNLUN_KERNEL_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
// Build a lane mask with bits 0..i set, for the masked vload_lm_
// intrinsics. For i == -1 the result is 0 (no lanes selected).
// Returns int instead of float: the value is a bit pattern, and the
// caller already assigns it to an int (implicit float->int round-trip
// was pointless and sloppy).
static inline __device__ int lowerBitMask(int i) {
    return (1 << (i + 1)) - 1;
}
// Sum of squares over `count` floats held in local memory (LM),
// vectorized with 16-lane float32 SIMD instructions.
// data_ptr: pointer into LM; count: number of valid elements.
// The tail of count % 16 elements is handled with a masked load.
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
    __local__ float acc_buf[16];
    int remain = count % 16;
    int offset_last = count - remain;
    // Lane mask selecting the `remain` tail elements (0 when remain == 0).
    int mask = lowerBitMask(remain - 1);
    // Masked load of the tail; the _mz variant presumably zero-fills the
    // masked-off lanes so they add nothing to the sum.
    // NOTE(review): when remain == 0 this still issues a load at
    // data_ptr + count with an all-zero mask — confirm the intrinsic
    // performs no out-of-bounds access in that case.
    float32x16_t v_last = vload_lm_float32x16_mz((data_ptr + offset_last), mask);
    // Square the tail lanes in place.
    v_last = vvmul_float32x16(v_last, v_last);
    // Accumulate the full 16-wide chunks: v_last += v_0 * v_0.
    for (int i = 0; i < offset_last; i += 16) {
        float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
        v_0 = vvmul_float32x16(v_0, v_0);
        v_last = vvadd_float32x16(v_last, v_0);
    }
    // Spill the 16 partial sums to LM, then reduce horizontally in scalar code.
    vstore_lm_float32x16_mz(acc_buf, v_last);
    mfence();  // make the vector store visible before the scalar reads below
    float res = 0.0f;
    for (int i = 0; i < 16; ++i) {
        res += acc_buf[i];
    }
    return res;
}
// Elementwise y[i] = w[i] * x[i] * rms over `count` floats in LM.
// The tail (count % 16 elements) is computed in scalar code first,
// then the full 16-wide chunks with SIMD.
static inline __device__ void elementMul(float *x, float *w, float *y, int count, float rms) {
    int remain = count % 16;
    int offset_last = count - remain;
    // Scalar loop for the remainder elements in [offset_last, count).
    for (int i = offset_last; i < count; i++) {
        *(y + i) = *(w + i) * *(x + i) * rms;
    }
    mfence();
    float32x16_t v_x;
    float32x16_t v_w;
    // Vector loop: y = (x * w) * rms, 16 floats at a time.
    for (int i = 0; i < offset_last; i += 16) {
        v_x = vload_lm_float32x16_mz(x + i);
        v_w = vload_lm_float32x16_mz(w + i);
        v_x = vvmul_float32x16(v_x, v_w);
        // Scale by the scalar rms factor.
        v_x = svmul_float32x16(rms, v_x);
        vstore_lm_float32x16((y + i), v_x);
        mfence();  // flush the vector store before the next iteration
    }
}
// Atomic float add into shared memory (SM), used for the cross-core
// reduction. Implemented as a read-modify-write retry loop over the
// SM2REG_atomic / REG2SM_atomic pair; the loop retries while
// REG2SM_atomic reports failure (non-zero).
// NOTE(review): the exact success/failure contract of these intrinsics
// is inferred from this usage — confirm against the XTDK documentation.
static inline __device__ void atomic_add(__shared_ptr__ float *ptr, float value) {
    int fail = 1;
    while (fail) {
        float a = SM2REG_atomic(ptr);
        a = a + value;
        fail = REG2SM_atomic(ptr, a);
    }
}
// RmsNorm main kernel:
//   y[i][j] = w[j] * x[i][j] / sqrt(mean(x[i][:]^2) + epsilon)
// for an m x n input. stride_x / stride_y are row strides (in elements).
// kunlun2 has 8 clusters and 64 cores; call as rms_norm<<<8, 32, stream>>>().
__global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float *w, int m, int n, float epsilon) {
    // ncores in a cluster
    int ncores = core_num();
    // id of the current core within the cluster
    int cid = core_id();
    if (cid >= ncores) {
        return;
    }
    // Divide m rows between clusters as evenly as possible: clusters with
    // cluster_id() < m % cluster_num() take one extra row.
    // [m_start, m_end) is the row range owned by this cluster.
    int m_start = m / cluster_num() * cluster_id() + min(m % cluster_num(), cluster_id());
    int m_end = m_start + (m / cluster_num()) + (cluster_id() < (m % cluster_num()));
    // max_nn: max number of elements processed per core per block
    const int max_nn = 1024;
    // max_mm: max number of rows processed per cluster per outer iteration
    const int max_mm = 1024;
    // LM cache for the reduce pass
    __local__ float x_local[max_nn];
    // Shared-memory accumulator: one partial sum-of-squares per row.
    __shared__ float sm_output[max_mm];
    // LM caches for the elementwise pass
    __local__ float y_local[max_nn];
    __local__ float w_local[max_nn];
    while (m_start < m_end) {
        // mm is the number of rows handled in this outer iteration.
        int mm = min(max_mm, m_end - m_start);
        // Zero the per-row accumulators.
        // BUGFIX: the loop bound must be mm, not (m_end - m_start) — the
        // latter can exceed max_mm on the first iteration and write past
        // the end of sm_output.
        for (int i = cid; i < mm; i += ncores) {
            sm_output[i] = 0.0f;
        }
        mfence();
        sync_cluster();
        // Each row is split into ceil(n / max_nn) blocks; total_block is
        // the number of (row, column-block) work items this iteration.
        int total_block = mm * roundup_div(n, max_nn);
        for (int curr_block = cid; curr_block < total_block; curr_block += ncores) {
            // curr_m: row of this work item; curr_n_start: first column;
            // curr_nn: number of elements in this block.
            int curr_m = curr_block % mm + m_start;
            int curr_n_start = (curr_block / mm) * max_nn;
            int curr_nn = min(max_nn, n - curr_n_start);
            auto x_ptr = x + curr_m * stride_x + curr_n_start;
            GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
            // Partial sum of squares for this block, reduced into SM.
            float ss = sumSquaredF32(x_local, curr_nn);
            atomic_add(&sm_output[curr_m - m_start], ss);
        }
        mfence();
        sync_cluster();
        // Elementwise normalization pass over the same work items.
        for (int blk = cid; blk < total_block; blk += ncores) {
            // Renamed from `m` to `row` to stop shadowing parameter `m`,
            // which is still read below in `ss / n`-style expressions.
            int row = blk % mm + m_start;
            int n_start = (blk / mm) * max_nn;
            int nn = min(max_nn, n - n_start);
            auto x_ptr = x + row * stride_x + n_start;
            auto w_ptr = w + n_start;
            GM2LM(x_ptr, x_local, nn * sizeof(float));
            GM2LM(w_ptr, w_local, nn * sizeof(float));
            float ss = SM2REG_atomic(sm_output + row - m_start);
            float rms = 1.0f / sqrt(ss / n + epsilon);
            elementMul(x_local, w_local, y_local, nn, rms);
            mfence();
            auto y_ptr = y + row * stride_y + n_start;
            LM2GM(y_local, y_ptr, nn * sizeof(float));
        }
        mfence();
        sync_cluster();
        m_start += max_mm;
    }
}
#endif
xmake/kunlun.lua
View file @
777fd6af
-- Enable the Kunlun backend API for everything built by this script.
add_defines("ENABLE_KUNLUN_API")

-- Toolchain locations come from the environment.
-- NOTE(review): os.getenv returns nil when KUNLUN_HOME is unset, and
-- path.join would then error — confirm the environment guarantees it.
local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
local XTDK_DIR = path.join(KUNLUN_HOME, "XTDK")

-- Add include dirs: expose the SDK headers to dependents too (public).
add_includedirs(path.join(KUNLUN_HOME, "include"), {public = true})
...
...
@@ -7,6 +8,52 @@ add_linkdirs(path.join(KUNLUN_HOME, "lib64"))
-- Link against the Kunlun runtime and API libraries.
add_links("xpurt", "xpuapi")
-- Custom build rule for handwritten Kunlun kernels (*.xpu).
-- xmake has no native XPU toolchain, so we invoke XTDK's clang++ directly.
rule("xpu")
    set_extensions(".xpu")

    on_load(function (target)
        -- Make the project-level public headers visible to kernel sources.
        target:add("includedirs", path.join(os.projectdir(), "include"))
    end)

    on_build_file(function (target, sourcefile)
        local objectfile = target:objectfile(sourcefile)
        -- XTDK's --basename expects the object path without its ".o" suffix.
        local basename = objectfile:gsub("%.o$", "")
        os.mkdir(path.directory(objectfile))

        local cc = path.join(XTDK_DIR, "bin/clang++")
        local args = {
            "--sysroot=/",
            "--target=aarch64-linux-gnu",
            "-fPIC",
            "-pie",
            "--xpu-arch=xpu2",
            "--basename", basename,
            "-std=c++11",
            "-O2",
            "-fno-builtin",
            "-g",
            "-c", sourcefile,
            "-v", -- verbose compiler output; drop once the flow is stable
        }
        -- Forward every include dir collected on the target.
        -- (BUGFIX: removed a dead `local includedirs = table.concat(...)`
        -- that was computed here and never used.)
        for _, includedir in ipairs(target:get("includedirs")) do
            table.insert(args, "-I" .. includedir)
        end
        os.execv(cc, args)

        -- XTDK emits a host object plus a device binary object; link both.
        table.insert(target:objectfiles(), objectfile)
        table.insert(target:objectfiles(), basename .. ".device.bin.o")
        print(target:objectfiles())
    end)
rule_end()
-- Root of the infiniop sources, used below to glob kernel files.
local src_dir = path.join(os.projectdir(), "src", "infiniop")
-- Static library holding all Kunlun operator implementations.
target("infiniop-kunlun")
    set_kind("static")
    add_deps("infini-utils")
...
...
@@ -17,6 +64,11 @@ target("infiniop-kunlun")
set_languages
(
"cxx17"
)
add_files
(
"$(projectdir)/src/infiniop/devices/kunlun/*.cc"
,
"$(projectdir)/src/infiniop/ops/*/kunlun/*.cc"
)
-- compile handwriting kernel
local
xpu_files
=
os
.
files
(
src_dir
..
"/ops/*/kunlun/*.xpu"
)
if
#
xpu_files
>
0
then
add_files
(
xpu_files
,
{
rule
=
"xpu"
})
end
target_end
()
target
(
"infinirt-kunlun"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment