Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8727bdcf
Commit
8727bdcf
authored
May 20, 2025
by
zhangyue
Browse files
rename private vars
parent
30bf79f1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
93 additions
and
116 deletions
+93
-116
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
+93
-116
No files found.
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
View file @
8727bdcf
...
...
@@ -32,21 +32,21 @@ private:
private:
TPipe
pipe
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
in
Q
ue
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
sin
Q
ue
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
cos
Q
ue
;
TQue
<
QuePosition
::
VECOUT
,
BUFFER_NUM
>
out
Q
ue
;
TBuf
<
TPosition
::
VECCALC
>
tmp
OddB
uf
;
TBuf
<
TPosition
::
VECCALC
>
tmp
E
ven
B
uf
;
TBuf
<
TPosition
::
VECCALC
>
tmp
OddB
uf1
;
TBuf
<
TPosition
::
VECCALC
>
tmp
OddB
uf2
;
TBuf
<
TPosition
::
VECCALC
>
tmp
E
ven
B
uf1
;
TBuf
<
TPosition
::
VECCALC
>
tmp
E
ven
B
uf2
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
_
in
_q
ue
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
_
sin
_q
ue
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
_
cos
_q
ue
;
TQue
<
QuePosition
::
VECOUT
,
BUFFER_NUM
>
_
out
_q
ue
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_odd_b
uf
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_e
ven
_b
uf
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_odd_b
uf1
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_odd_b
uf2
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_e
ven
_b
uf1
;
TBuf
<
TPosition
::
VECCALC
>
_
tmp
_e
ven
_b
uf2
;
GlobalTensor
<
T
>
xGm
,
yG
m
;
GlobalTensor
<
U
>
pG
m
;
GlobalTensor
<
T
>
sin
G
m
;
GlobalTensor
<
T
>
cos
G
m
;
GlobalTensor
<
T
>
_x_gm
,
_y_g
m
;
GlobalTensor
<
U
>
_p_g
m
;
GlobalTensor
<
T
>
_
sin
_g
m
;
GlobalTensor
<
T
>
_
cos
_g
m
;
size_t
_block_idx
;
size_t
_tile_len
;
...
...
@@ -83,57 +83,57 @@ __aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
_block_idx
=
GetBlockIdx
();
// Init global buffer
xG
m
.
SetGlobalBuffer
((
__gm__
T
*
)
x
);
pG
m
.
SetGlobalBuffer
((
__gm__
U
*
)
pos
);
sin
G
m
.
SetGlobalBuffer
((
__gm__
T
*
)
sin
);
cos
G
m
.
SetGlobalBuffer
((
__gm__
T
*
)
cos
);
yG
m
.
SetGlobalBuffer
((
__gm__
T
*
)
y
);
_x_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
x
);
_p_g
m
.
SetGlobalBuffer
((
__gm__
U
*
)
pos
);
_
sin
_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
sin
);
_
cos
_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
cos
);
_y_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
y
);
// Init Queue buffer
pipe
.
InitBuffer
(
in
Q
ue
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
out
Q
ue
,
BUFFER_NUM
,
_tile_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
sin
Q
ue
,
BUFFER_NUM
,
_half_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
cos
Q
ue
,
BUFFER_NUM
,
_half_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
OddB
uf
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
E
ven
B
uf
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
OddB
uf1
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
OddB
uf2
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
E
ven
B
uf1
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
tmp
E
ven
B
uf2
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
in
_q
ue
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
out
_q
ue
,
BUFFER_NUM
,
_tile_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
sin
_q
ue
,
BUFFER_NUM
,
_half_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
cos
_q
ue
,
BUFFER_NUM
,
_half_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_odd_b
uf
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_e
ven
_b
uf
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_odd_b
uf1
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_odd_b
uf2
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_e
ven
_b
uf1
,
_tile_len
/
2
*
sizeof
(
T
));
pipe
.
InitBuffer
(
_
tmp
_e
ven
_b
uf2
,
_tile_len
/
2
*
sizeof
(
T
));
}
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
copyIn
(
size_t
i
)
{
LocalTensor
<
T
>
input
U
b
=
in
Q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
sin
U
b
=
sin
Q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
cos
U
b
=
cos
Q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
input
_u
b
=
_
in
_q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
sin
_u
b
=
_
sin
_q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
cos
_u
b
=
_
cos
_q
ue
.
AllocTensor
<
T
>
();
// Get idx of current tile in total input
auto
idx
=
i
*
_st_xnt
+
_block_idx
*
_st_xnh
;
// Copy tile current tile into UB
DataCopy
(
input
U
b
,
xG
m
[
idx
],
_copy_len
);
DataCopy
(
input
_u
b
,
_x_g
m
[
idx
],
_copy_len
);
// Copy sin cos tile
auto
pos_idx
=
pG
m
(
i
);
DataCopy
(
sin
U
b
,
sin
G
m
[
pos_idx
*
_tile_len
/
2
],
_half_copy_len
);
DataCopy
(
cos
U
b
,
cos
G
m
[
pos_idx
*
_tile_len
/
2
],
_half_copy_len
);
auto
pos_idx
=
_p_g
m
(
i
);
DataCopy
(
sin
_u
b
,
_
sin
_g
m
[
pos_idx
*
_tile_len
/
2
],
_half_copy_len
);
DataCopy
(
cos
_u
b
,
_
cos
_g
m
[
pos_idx
*
_tile_len
/
2
],
_half_copy_len
);
// Push in operands
in
Q
ue
.
EnQue
(
input
U
b
);
sin
Q
ue
.
EnQue
(
sin
U
b
);
cos
Q
ue
.
EnQue
(
cos
U
b
);
_
in
_q
ue
.
EnQue
(
input
_u
b
);
_
sin
_q
ue
.
EnQue
(
sin
_u
b
);
_
cos
_q
ue
.
EnQue
(
cos
_u
b
);
}
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
compute
(
size_t
i
)
{
LocalTensor
<
T
>
input
U
b
=
in
Q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
sin
U
b
=
sin
Q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cos
U
b
=
cos
Q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
output
U
b
=
out
Q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
input
_u
b
=
_
in
_q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
sin
_u
b
=
_
sin
_q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cos
_u
b
=
_
cos
_q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
output
_u
b
=
_
out
_q
ue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
tmp
O
dd
=
tmp
OddB
uf
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
E
ven
=
tmp
E
ven
B
uf
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
O
dd1
=
tmp
OddB
uf1
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
O
dd2
=
tmp
OddB
uf2
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
E
ven1
=
tmp
E
ven
B
uf1
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
E
ven2
=
tmp
E
ven
B
uf2
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_o
dd
=
_
tmp
_odd_b
uf
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_e
ven
=
_
tmp
_e
ven
_b
uf
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_o
dd1
=
_
tmp
_odd_b
uf1
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_o
dd2
=
_
tmp
_odd_b
uf2
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_e
ven1
=
_
tmp
_e
ven
_b
uf1
.
Get
<
T
>
();
LocalTensor
<
T
>
tmp
_e
ven2
=
_
tmp
_e
ven
_b
uf2
.
Get
<
T
>
();
// separate odd and even bit elements
uint64_t
rsvdCnt
=
0
;
...
...
@@ -143,43 +143,43 @@ __aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
8
,
8
,
};
GatherMask
<
T
>
(
tmp
O
dd
,
input
U
b
,
1
,
false
,
0
,
gMaskParams
,
rsvdCnt
);
GatherMask
<
T
>
(
tmp
E
ven
,
input
U
b
,
2
,
false
,
0
,
gMaskParams
,
rsvdCnt
);
GatherMask
<
T
>
(
tmp
_o
dd
,
input
_u
b
,
1
,
false
,
0
,
gMaskParams
,
rsvdCnt
);
GatherMask
<
T
>
(
tmp
_e
ven
,
input
_u
b
,
2
,
false
,
0
,
gMaskParams
,
rsvdCnt
);
PipeBarrier
<
PIPE_V
>
();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul
<
T
>
(
tmp
O
dd1
,
tmp
O
dd
,
cos
U
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
O
dd2
,
tmp
E
ven
,
sin
U
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
_o
dd1
,
tmp
_o
dd
,
cos
_u
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
_o
dd2
,
tmp
_e
ven
,
sin
_u
b
,
_tile_len
/
2
);
PipeBarrier
<
PIPE_V
>
();
Sub
<
T
>
(
tmp
O
dd1
,
tmp
O
dd1
,
tmp
O
dd2
,
_tile_len
/
2
);
Sub
<
T
>
(
tmp
_o
dd1
,
tmp
_o
dd1
,
tmp
_o
dd2
,
_tile_len
/
2
);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul
<
T
>
(
tmp
E
ven1
,
tmp
O
dd
,
sin
U
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
E
ven2
,
tmp
E
ven
,
cos
U
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
_e
ven1
,
tmp
_o
dd
,
sin
_u
b
,
_tile_len
/
2
);
Mul
<
T
>
(
tmp
_e
ven2
,
tmp
_e
ven
,
cos
_u
b
,
_tile_len
/
2
);
PipeBarrier
<
PIPE_V
>
();
Add
<
T
>
(
tmp
E
ven1
,
tmp
E
ven1
,
tmp
E
ven2
,
_tile_len
/
2
);
Add
<
T
>
(
tmp
_e
ven1
,
tmp
_e
ven1
,
tmp
_e
ven2
,
_tile_len
/
2
);
// combine odd and even bit elements
for
(
uint32_t
j
=
0
;
j
<
_tile_len
/
2
;
j
+=
1
)
{
output
U
b
(
j
*
2
)
=
tmp
O
dd1
(
j
);
output
U
b
(
j
*
2
+
1
)
=
tmp
E
ven1
(
j
);
output
_u
b
(
j
*
2
)
=
tmp
_o
dd1
(
j
);
output
_u
b
(
j
*
2
+
1
)
=
tmp
_e
ven1
(
j
);
}
out
Q
ue
.
EnQue
<
T
>
(
output
U
b
);
in
Q
ue
.
FreeTensor
(
input
U
b
);
sin
Q
ue
.
FreeTensor
(
sin
U
b
);
cos
Q
ue
.
FreeTensor
(
cos
U
b
);
_
out
_q
ue
.
EnQue
<
T
>
(
output
_u
b
);
_
in
_q
ue
.
FreeTensor
(
input
_u
b
);
_
sin
_q
ue
.
FreeTensor
(
sin
_u
b
);
_
cos
_q
ue
.
FreeTensor
(
cos
_u
b
);
}
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
copyOut
(
size_t
i
)
{
LocalTensor
<
T
>
output
U
b
=
out
Q
ue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
output
_u
b
=
_
out
_q
ue
.
DeQue
<
T
>
();
auto
idy
=
i
*
_st_ynt
+
_block_idx
*
_st_ynh
;
DataCopyExtParams
params
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
DataCopyPad
(
yG
m
[
idy
],
output
U
b
,
params
);
out
Q
ue
.
FreeTensor
(
output
U
b
);
DataCopyPad
(
_y_g
m
[
idy
],
output
_u
b
,
params
);
_
out
_q
ue
.
FreeTensor
(
output
_u
b
);
}
template
<
typename
T
,
typename
U
>
...
...
@@ -192,56 +192,30 @@ __aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
}
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
case INFINI_DTYPE_I8: { \
RoPEKernel<TYPE, int8_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I16: { \
RoPEKernel<TYPE, int16_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I32: { \
RoPEKernel<TYPE, int32_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I64: { \
RoPEKernel<TYPE, int64_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U8: { \
RoPEKernel<TYPE, uint8_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U16: { \
RoPEKernel<TYPE, uint16_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U32: { \
RoPEKernel<TYPE, uint32_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U64: { \
RoPEKernel<TYPE, uint64_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
...
...
@@ -264,6 +238,9 @@ DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL
(
rope_kernel_half
,
half
)
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
extern
"C"
infiniStatus_t
rope_kernel_launch
(
void
*
y
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment