Commit eb3972eb (jerrrrry/infinicore)
Authored Aug 20, 2025 by zhangyue
issues/385: add multi-precision support for p800 rmsnorm
parent 60ca4508
Showing 6 changed files with 306 additions and 160 deletions.
src/infiniop/devices/kunlun/kunlun_kernel_common.h    +84  -11
src/infiniop/ops/rms_norm/kunlun/kernel.h             +35  -0
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu  +0   -125
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h    +10  -0
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu  +137 -0
src/infiniop/reduce/kunlun/reduce_kunlun.h            +40  -24
src/infiniop/devices/kunlun/kunlun_kernel_common.h
...
...
@@ -4,17 +4,26 @@
// This header file will only be included by .xpu files
#include "xpu/runtime.h"
#include <xpu/kernel/xtdk.h>
#include <xpu/kernel/xtdk_atomic_sm_xpu3.h>
#include <xpu/kernel/xtdk_bf16.h>
#include <xpu/kernel/xtdk_math.h>
#include <xpu/kernel/xtdk_simd.h>
#include <xpu/kernel/xtdk_trigonometric.h>
// #include <xpu/kernel/xtdk_io.h>

namespace device::kunlun::kernel {

#define SM_SIZE 10240

/**
 * @brief Define ptrdiff_t and size_t for the Kunlun XPU.
 * ptrdiff_t is 32-bit and size_t is 32-bit in an XPU kernel;
 * we pad them to 64 bits for the convenience of DATACOPY.
 */
typedef struct _ptrdiff_t {
    int32_t value;   // 32 bit
    int32_t padding; // 32 bit
} _ptrdiff_t;

// Same as _ptrdiff_t
typedef struct _size_t {
    uint32_t value;
...
...
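The padding field exists purely so the 32-bit value occupies a full 64-bit slot for DATACOPY transfers. A host-side sanity check of that layout (my sketch; the mirror struct and static_assert are not part of the commit) could look like:

#include <cstdint>

// Host-side mirror of the kernel struct, used only for a compile-time size check.
struct HostPtrdiff {
    int32_t value;   // the actual 32-bit offset
    int32_t padding; // pads the struct out to 64 bits for DATACOPY
};
static_assert(sizeof(HostPtrdiff) == 8, "expected a 64-bit padded layout");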
@@ -29,17 +38,83 @@ inline __device__ float lowerBitMask(int i) {
    return (1 << (i + 1)) - 1;
}

// Atomic add for reduce
inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
    int success = 1;
    while (success) {
        // SM2REG reads 32-bit data into a register
        float a = SM2REG_atomic(ptr);
        a = a + value;
        success = REG2SM_atomic(ptr, a);
    }
}
/**
 * @brief Load data from shared memory
 * @param p: pointer to shared memory
 * @return loaded value
 */
template <typename T>
__device__ inline T loadsm(__shared_ptr__ const T *p) {
    T v;
    if constexpr (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
        __builtin_memcpy(&v, p, sizeof(T));
    } else {
        v = *p;
    }
    return v;
}

// Load len elements from shared memory
template <typename T>
__device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
    __builtin_memcpy(v, p, len * sizeof(T));
}
/**
 * @brief Convert data type. All data is in local memory
 * @param v: input value
 * @return output value
 */
template <typename Tout, typename Tin>
__device__ inline Tout to(Tin v) {
    if constexpr (std::is_same<Tin, half>::value) {
        return __half2float(v);
    } else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
        return __bfloat162float(v);
    } else {
        return static_cast<Tout>(v);
    }
}
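A hypothetical usage sketch (the buffer x_sm and index i are illustrative, not from this commit): loadsm plus to<> is the pattern the kernels below use to read 16-bit elements out of shared memory and widen them for float accumulation:

// Inside a kernel, assuming __shared__ half x_sm[...] was filled earlier:
half xi = loadsm(x_sm + i);  // memcpy-based load, safe for 16-bit types in SM
float xf = to<float>(xi);    // widen to the compute type before accumulating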
/**
 * @brief atomicAdd for the Kunlun XPU
 * @param ptr: pointer to shared memory
 * @param value: value to add
 */
template <typename T>
inline __device__ T atomicAdd(__shared_ptr__ T *ptr, T value) {
    T x = atomicadd(ptr, value);
    return x;
}

// Specialize atomicAdd for half
template <>
inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) {
    ticket_lock_mix();
    __half old = loadsm(ptr);
    float of = __half2float(old);
    float vf = __half2float(value);
    float sumf = of + vf;
    half sum = __float2half_rn(sumf);
    *ptr = sum;
    mfence_sm();
    ticket_unlock_mix();
    return old;
}

// Specialize atomicAdd for bfloat16_t
template <>
inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *ptr, bfloat16_t value) {
    ticket_lock_mix();
    bfloat16_t old = loadsm(ptr);
    float of = __bfloat162float(old);
    float vf = __bfloat162float(value);
    float sumf = of + vf;
    bfloat16_t sum = __float2bfloat16_rn(sumf);
    *ptr = sum;
    mfence_sm();
    ticket_unlock_mix();
    return old;
}
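The half and bfloat16_t specializations emulate a 16-bit atomic add with a ticket lock plus float arithmetic, since the hardware atomicadd here only covers the wider types. As a rough portable analogue (my sketch, not part of the commit; h2f/f2h are caller-supplied stand-ins for __half2float/__float2half_rn), the same read-modify-write can be expressed with a CAS loop:

#include <atomic>
#include <cstdint>

// CPU-side sketch of a 16-bit atomic add: the payload is a raw uint16_t
// standing in for half; h2f/f2h convert to and from float.
uint16_t atomic_add_16(std::atomic<uint16_t> &slot, uint16_t value,
                       float (*h2f)(uint16_t), uint16_t (*f2h)(float)) {
    uint16_t old = slot.load(std::memory_order_relaxed);
    // compare_exchange_weak reloads `old` on failure, so this retries
    // until no other thread modified the slot in between.
    while (!slot.compare_exchange_weak(old, f2h(h2f(old) + h2f(value)),
                                       std::memory_order_acq_rel)) {
    }
    return old; // previous value, matching the atomicAdd contract above
}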
/**
 * @brief Get index of broadcasted input
 * flat_index: flattened index of the output tensor
...
...
@@ -85,5 +160,3 @@ inline __device__ int indexToOffset(
}

} // namespace device::kunlun::kernel

#endif // __INFINIOP_KUNLUN_KERNEL_COMMON_H__
// TODO: atomicAddF16
// TODO: atomicAddI8
src/infiniop/ops/rms_norm/kunlun/kernel.h (new file, 0 → 100644)
#ifndef __RMS_NORM_KUNLUN_KERNEL_H__
#define __RMS_NORM_KUNLUN_KERNEL_H__

#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"

using namespace device::kunlun::kernel;

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void rmsnormBlock(
    __shared_ptr__ Tdata *y,
    __shared_ptr__ const Tdata *x,
    __shared_ptr__ const Tweight *w,
    size_t dim,
    float epsilon) {
    // Block reduce sum of x^2
    Tcompute ss = op::common_kunlun::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(x, dim);
    __shared__ Tcompute rms;
    if (core_id() == 0) {
        rms = Tcompute(rsqrt(ss / Tcompute(dim) + epsilon));
    }
    sync_cluster();
    // Copy contiguous x, w into local mem (load from shared memory safely)
    for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
        Tdata xi = loadsm(x + i);
        Tweight wi = loadsm(w + i);
        y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms);
    }
    sync_cluster();
}

#endif
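For reference, rmsnormBlock implements the standard RMSNorm formula, with the reduction and the scaling performed in Tcompute (float in the instantiations below) before the result is narrowed back to Tdata:

    y_i = w_i \, x_i \cdot \operatorname{rsqrt}\!\left(\frac{1}{d} \sum_{j=1}^{d} x_j^2 + \epsilon\right), \qquad i = 1, \dots, d

where d is dim and \epsilon is epsilon.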
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu (deleted, 100644 → 0)
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
// Element wise mul used in x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder
for (int i = offset_last; i < count; i++) {
*(y + i) = *(w + i) * *(x + i) * rms;
}
mfence();
float32x16_t v_x;
float32x16_t v_w;
// Do x * w * rms
for (int i = 0; i < offset_last; i += 16) {
v_x = vload_lm_float32x16_mz(x + i);
v_w = vload_lm_float32x16_mz(w + i);
v_x = vvmul_float32x16(v_x, v_w);
v_x = svmul_float32x16(rms, v_x);
vstore_lm_float32x16((y + i), v_x);
mfence();
}
}
// RmsNorm main kernel func
// kunlun2 has 8 clusters and 64 cores
// Call it by rmsnorm<<<8, 32, stream>>>()
__global__ void rmsNormKernelF32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster
int ncores = core_num();
// get cid of current core
int cid = core_id();
if (cid >= ncores) {
return;
}
// Divide m rows into all clusters equally
// if m % cluster_num() != 0, clusters with cluster_id < m % cluster_num() handle one extra row
// [m_start, m_end) is the range of m dim in current cluster
int m_start = m / cluster_num() * cluster_id() + min(m % cluster_num(), cluster_id());
int m_end = m_start + (m / cluster_num()) + (cluster_id() < (m % cluster_num()));
// max_nn is the max number of elements calculated on one core
const int max_nn = 1024;
// max_mm is the max number of rows calculated on one cluster
const int max_mm = 1024;
// LM cache for reduce
__local__ float x_local[max_nn];
// sm_output is shared mem cache for reduce
__shared__ float sm_output[max_mm];
// LM cache for elementwise mul
__local__ float y_local[max_nn];
__local__ float w_local[max_nn];
while (m_start < m_end) {
// init sm_output
for (int i = cid; i < m_end - m_start; i += ncores) {
sm_output[i] = 0.0f;
}
mfence();
sync_cluster();
// mm is the number of rows on current cluster
int mm = min(max_mm, m_end - m_start);
// each row will be divided into several blocks
// total_block is the number of blocks calculated on current cluster
// curr_block is the block calculated on current core
int total_block = mm * roundup_div(n, max_nn);
for (int curr_block = cid; curr_block < total_block; curr_block += ncores) {
// curr_m is the row of curr_block;
// curr_n_start is the first element of current row
// curr_nn is the number of elements of curr_block
int curr_m = curr_block % mm + m_start;
int curr_n_start = (curr_block / mm) * max_nn;
int curr_nn = min(max_nn, n - curr_n_start);
auto x_ptr = x + curr_m * stride_x + curr_n_start;
GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
// do reduce
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
atomicAddF32(&sm_output[curr_m - m_start], ss);
}
mfence();
sync_cluster();
// do elementwise mul for every line
for (int blk = cid; blk < total_block; blk += ncores) {
int m = blk % mm + m_start;
int n_start = (blk / mm) * max_nn;
int nn = min(max_nn, n - n_start);
auto x_ptr = x + m * stride_x + n_start;
auto w_ptr = w + n_start;
GM2LM(x_ptr, x_local, nn * sizeof(float));
GM2LM(w_ptr, w_local, nn * sizeof(float));
float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon);
elementwiseMulRms(x_local, w_local, y_local, nn, rms);
mfence();
auto y_ptr = y + m * stride_y + n_start;
LM2GM(y_local, y_ptr, nn * sizeof(float));
}
mfence();
sync_cluster();
m_start += max_mm;
}
}
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rmsNormKernelF32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
}
#endif
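To make the deleted kernel's row partitioning concrete (a worked example added here, not part of the commit): with m = 10 rows and cluster_num() = 8, m / 8 = 1 and m % 8 = 2, so m_start = cluster_id() + min(2, cluster_id()). Clusters 0 and 1 each take two rows ([0, 2) and [2, 4)), while clusters 2 through 7 take one row each ([4, 5) up to [9, 10)).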
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h
...
...
@@ -5,4 +5,14 @@
DESCRIPTOR(kunlun)

#define INSTANTIATE_RMSNORM_KERNEL(BLOCK_SIZE, Tcompute, Tdata, Tweight) \
    template __global__ void rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight>( \
        Tdata * y, \
        int32_t stride_y, \
        const Tdata *x, \
        int32_t stride_x, \
        const Tweight *w, \
        uint32_t dim, \
        float epsilon);
#endif
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.cc → src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu (renamed)
#include "rms_norm_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h"
#include "rms_norm_kunlun.h"
#include <memory>
#include <stdint.h>
void
rmsNormF32
(
void
*
y
,
long
stride_y
,
const
void
*
x
,
long
stride_x
,
const
void
*
w
,
int
m
,
int
n
,
float
epsilon
,
XPUStream
stream
);
// Kernel function for computing RMS-norm
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__global__ void rmsnormKernel(
Tdata *y,
int32_t stride_y,
const Tdata *x,
int32_t stride_x,
const Tweight *w,
uint32_t dim,
float epsilon) {
__shared__ Tdata x_sm[SM_SIZE / sizeof(Tdata)];
__shared__ Tweight w_sm[SM_SIZE / sizeof(Tweight)];
__shared__ Tdata y_sm[SM_SIZE / sizeof(Tdata)];
// Copy x and w to shared memory in 0 core
if (core_id() == 0) {
GM2SM_ASYNC(x + stride_x * cluster_id(), x_sm, dim * sizeof(Tdata));
GM2SM_ASYNC(w, w_sm, dim * sizeof(Tweight));
}
sync_cluster();
// Compute RMS-norm in shared memory
rmsnormBlock<BLOCK_SIZE, Tcompute>(y_sm, x_sm, w_sm, dim, epsilon);
if (core_id() == 0) {
SM2GM_ASYNC(y_sm, y + stride_y * cluster_id(), dim * sizeof(Tdata));
}
sync_cluster();
}
// Instantiate the kernel for different data types and block sizes
INSTANTIATE_RMSNORM_KERNEL(64, float, float, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, bfloat16_t);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, half);
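One practical limit worth noting (an observation on the code above, not stated in the commit): x_sm, w_sm, and y_sm each hold SM_SIZE / sizeof(T) elements, so with SM_SIZE = 10240 bytes a row must satisfy dim <= 2560 for float and dim <= 5120 for half or bfloat16_t; as written, the kernel does not tile longer rows.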
namespace op::rms_norm::kunlun {
...
...
@@ -24,13 +62,11 @@ infiniStatus_t Descriptor::create(
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (info.ndim() != 2) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
...
...
@@ -44,35 +80,57 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
-    int m,
-    int n,
+    uint32_t batch_size,
+    uint32_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y,
    const void *x, ptrdiff_t stride_x,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    kunlunStream_t stream) {
-    if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
-        rmsNormF32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream);
-    }
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
+    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, stream>>>( \
+        static_cast<Tdata *>(y), stride_y, \
+        static_cast<const Tdata *>(x), stride_x, \
+        static_cast<const Tweight *>(w), dim, epsilon);
+
+    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(half, half, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(half, float, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(bfloat16_t, float, float);
+    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(float, float, float);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
}
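A hypothetical call site (all values illustrative; y, x, w, and stream are assumed to be valid device pointers and an existing kunlunStream_t): half activations with float weights over 8 contiguous rows of width 4096 would dispatch to LAUNCH_KERNEL(half, float, float):

// Sketch only: argument values are made up for illustration.
launchKernel<64>(/*batch_size=*/8, /*dim=*/4096,
                 y, INFINI_DTYPE_F16, /*stride_y=*/4096,
                 x, /*stride_x=*/4096,
                 w, INFINI_DTYPE_F32,
                 /*epsilon=*/1e-5f, stream);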
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *x, const void *w,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
-    auto stride_x = _info.x_strides[0];
-    auto stride_y = _info.y_strides[0];
-    int n = static_cast<int>(_info.dim());
-    int m = static_cast<int>(_info.shape[0]);
-    launchKernel(m, n, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, reinterpret_cast<kunlunStream_t>(stream));
+    auto stride_x = static_cast<int32_t>(_info.x_strides[0]);
+    auto stride_y = static_cast<int32_t>(_info.y_strides[0]);
+    auto dim = static_cast<uint32_t>(_info.dim());
+    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    // launch kernel with different block sizes
+    CHECK_STATUS(launchKernel<64>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, stream));
    return INFINI_STATUS_SUCCESS;
}
...
...
src/infiniop/reduce/kunlun/reduce_kunlun.h
...
...
@@ -7,32 +7,48 @@ namespace op::common_kunlun::reduce_op {
using namespace device::kunlun::kernel;
-// Use 16-float SIMD instructions to calculate the reduction
-// data_ptr is a pointer into LM
-static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
-    __local__ float acc_buf[16];
-    int remain = count % 16;
-    int offset_last = count - remain;
-    int mask = lowerBitMask(remain - 1);
-    // Load the last 16 data
-    float32x16_t v_last = vload_lm_float32x16_mz((data_ptr + offset_last), mask);
-    // Do v_last * v_last
-    v_last = vvmul_float32x16(v_last, v_last);
-    // for every 16 float data
-    for (int i = 0; i < offset_last; i += 16) {
-        float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
-        // Do v_0 * v_0
-        v_0 = vvmul_float32x16(v_0, v_0);
-        // Add to v_last
-        v_last = vvadd_float32x16(v_last, v_0);
-    }
-    vstore_lm_float32x16_mz(acc_buf, v_last);
-    mfence();
-    float res = 0.0f;
-    for (int i = 0; i < 16; ++i) {
-        res += acc_buf[i];
-    }
-    return res;
-}
+// Sum(x^2) on contiguous data of length count
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
+__device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size_t count) {
+    Tcompute ss = 0;
+    for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
+        Tdata xi = loadsm(data_ptr + i);
+        ss += to<Tcompute>(xi) * to<Tcompute>(xi);
+    }
+    __shared__ Tcompute temp_storage;
+    if (core_id() == 0) {
+        temp_storage = 0;
+    }
+    sync_cluster();
+    atomicAdd(&temp_storage, ss);
+    sync_cluster();
+    return temp_storage;
+}
+
+// Sum(x) on contiguous data of length count
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
+__device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t count) {
+    Tcompute ss = 0;
+    for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
+        Tdata xi = loadsm(data_ptr + i);
+        ss += to<Tcompute>(xi);
+    }
+    __shared__ Tcompute temp_storage;
+    if (core_id() == 0) {
+        temp_storage = 0;
+    }
+    sync_cluster();
+    atomicAdd(&temp_storage, ss);
+    sync_cluster();
+    return temp_storage;
+}
} // namespace op::common_kunlun::reduce_op
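A host-side reference for validating the device reduction (my sketch, not part of the commit), accumulating in float exactly as sumSquared does when Tcompute = float:

#include <cmath>
#include <cstdio>
#include <vector>

// Reference sum of squares plus the rms factor as rmsnormBlock forms it.
int main() {
    std::vector<float> x = {1.0f, 2.0f, 3.0f};
    float ss = 0.0f;
    for (float v : x) {
        ss += v * v; // mirrors ss += to<float>(xi) * to<float>(xi)
    }
    float epsilon = 1e-5f;
    float rms = 1.0f / std::sqrt(ss / static_cast<float>(x.size()) + epsilon);
    std::printf("ss = %f, rms = %f\n", ss, rms); // ss = 14, rms ~= 0.4629
    return 0;
}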
...
...