Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
650f3057
Commit
650f3057
authored
Apr 10, 2025
by
zhangyue
Browse files
issue/111: 函数名称和参数类型的修改,不涉及对外接口
parent
1c762900
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
13 deletions
+16
-13
src/infiniop/devices/kunlun/kunlun_common.h
src/infiniop/devices/kunlun/kunlun_common.h
+8
-5
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
+6
-6
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.cc
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.cc
+2
-2
No files found.
src/infiniop/devices/kunlun/kunlun_common.h
View file @
650f3057
...
@@ -9,18 +9,21 @@
...
@@ -9,18 +9,21 @@
// Get mask for vload_lm_ func
// Get mask for vload_lm_ func
// 0 - i bit 1, others 0
// 0 - i bit 1, others 0
static
inline
__device__
float
lowerBitMask
(
int
i
)
{
inline
__device__
float
lowerBitMask
(
int
i
)
{
return
(
1
<<
(
i
+
1
))
-
1
;
return
(
1
<<
(
i
+
1
))
-
1
;
}
}
// Atomic add for reduce
// Atomic add for reduce
static
inline
__device__
void
atomic_add
(
__shared_ptr__
float
*
ptr
,
float
value
)
{
inline
__device__
void
atomicAddF32
(
__shared_ptr__
float
*
ptr
,
float
value
)
{
int
fail
=
1
;
int
success
=
1
;
while
(
fail
)
{
while
(
success
)
{
// SM2REG read 32bit data to register
float
a
=
SM2REG_atomic
(
ptr
);
float
a
=
SM2REG_atomic
(
ptr
);
a
=
a
+
value
;
a
=
a
+
value
;
fail
=
REG2SM_atomic
(
ptr
,
a
);
success
=
REG2SM_atomic
(
ptr
,
a
);
}
}
}
}
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif
#endif
src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu
View file @
650f3057
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include "../../../reduce/kunlun/reduce_kunlun.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
// Element wise mul used in x * w
// Element wise mul used in x * w
static inline __device__ void element
M
ul(float *x, float *w, float *y, int count, float rms) {
static inline __device__ void element
_m
ul(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int remain = count % 16;
int offset_last = count - remain;
int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder
// y[i] = w[i] * x[i] * rms for remainder
...
@@ -29,7 +29,7 @@ static inline __device__ void elementMul(float *x, float *w, float *y, int count
...
@@ -29,7 +29,7 @@ static inline __device__ void elementMul(float *x, float *w, float *y, int count
// RmsNorm main kernel func
// RmsNorm main kernel func
// kunlun2 has 8 cluster and 64 core
// kunlun2 has 8 cluster and 64 core
// Call it by rmsnorm<<<8, 32, stream>>>()
// Call it by rmsnorm<<<8, 32, stream>>>()
__global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float *w, int m, int n, float epsilon) {
__global__ void rms_norm
_f32
(float *y, long stride_y,
const
float *x, long stride_x,
const
float *w, int m, int n, float epsilon) {
// ncores in a cluster
// ncores in a cluster
int ncores = core_num();
int ncores = core_num();
// get cid of current core
// get cid of current core
...
@@ -85,7 +85,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
...
@@ -85,7 +85,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
// do reduce
// do reduce
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
atomic
_add
(&sm_output[curr_m - m_start], ss);
atomic
AddF32
(&sm_output[curr_m - m_start], ss);
}
}
mfence();
mfence();
sync_cluster();
sync_cluster();
...
@@ -103,7 +103,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
...
@@ -103,7 +103,7 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
float ss = SM2REG_atomic(sm_output + m - m_start);
float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon);
float rms = 1.0f / sqrt(ss / n + epsilon);
element
M
ul(x_local, w_local, y_local, nn, rms);
element
_m
ul(x_local, w_local, y_local, nn, rms);
mfence();
mfence();
auto y_ptr = y + m * stride_y + n_start;
auto y_ptr = y + m * stride_y + n_start;
...
@@ -116,8 +116,8 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
...
@@ -116,8 +116,8 @@ __global__ void rms_norm(float *y, long stride_y, float *x, long stride_x, float
}
}
}
}
void rms
_n
orm
_f
32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
void rms
N
orm
F
32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rms_norm<<<8, 32, stream>>>((float *)y, stride_y, (float *)x, stride_x, (float *)w, m, n, epsilon);
rms_norm
_f32
<<<8, 32, stream>>>((float *)y, stride_y, (
const
float *)x, stride_x, (
const
float *)w, m, n, epsilon);
}
}
#endif
#endif
src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.cc
View file @
650f3057
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include <memory>
#include <memory>
#include <stdint.h>
#include <stdint.h>
void
rms
_n
orm
_f
32
(
void
*
y
,
long
stride_y
,
const
void
*
x
,
long
stride_x
,
const
void
*
w
,
int
m
,
int
n
,
float
epsilon
,
XPUStream
stream
);
void
rms
N
orm
F
32
(
void
*
y
,
long
stride_y
,
const
void
*
x
,
long
stride_x
,
const
void
*
w
,
int
m
,
int
n
,
float
epsilon
,
XPUStream
stream
);
namespace
op
::
rms_norm
::
kunlun
{
namespace
op
::
rms_norm
::
kunlun
{
...
@@ -53,7 +53,7 @@ infiniStatus_t launchKernel(
...
@@ -53,7 +53,7 @@ infiniStatus_t launchKernel(
kunlunStream_t
stream
)
{
kunlunStream_t
stream
)
{
if
(
atype
==
INFINI_DTYPE_F32
&&
wtype
==
INFINI_DTYPE_F32
)
{
if
(
atype
==
INFINI_DTYPE_F32
&&
wtype
==
INFINI_DTYPE_F32
)
{
rms
_n
orm
_f
32
(
y
,
static_cast
<
long
>
(
stride_y
),
x
,
static_cast
<
long
>
(
stride_x
),
w
,
m
,
n
,
epsilon
,
stream
);
rms
N
orm
F
32
(
y
,
static_cast
<
long
>
(
stride_y
),
x
,
static_cast
<
long
>
(
stride_x
),
w
,
m
,
n
,
epsilon
,
stream
);
}
else
{
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment