Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
2c0e9a6e
Commit
2c0e9a6e
authored
Apr 29, 2025
by
zhangyunze
Committed by
zhangyue
May 15, 2025
Browse files
fix:解除kernel中对postype的限制
parent
bbb0105b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
29 deletions
+81
-29
src/infiniop/ops/rope/ascend/rope_aclnn.cc
src/infiniop/ops/rope/ascend/rope_aclnn.cc
+3
-3
src/infiniop/ops/rope/ascend/rope_kernel.cpp
src/infiniop/ops/rope/ascend/rope_kernel.cpp
+78
-26
No files found.
src/infiniop/ops/rope/ascend/rope_aclnn.cc
View file @
2c0e9a6e
...
...
@@ -33,6 +33,7 @@ extern "C" infiniStatus_t rope_kernel_launch(
int32_t
nhead
,
int32_t
dhead
,
int32_t
data_type
,
int32_t
pos_type
,
int32_t
y_stride_seqlen
,
int32_t
y_stride_nhead
,
int32_t
x_stride_seqlen
,
...
...
@@ -48,17 +49,16 @@ infiniStatus_t Descriptor::calculate(
const
void
*
sin_table
,
const
void
*
cos_table
,
void
*
stream
)
const
{
// TODO: 是否有可能解除这个判断
CHECK_DTYPE
(
_info
.
pos_type
,
INFINI_DTYPE_U32
);
CHECK_DTYPE
(
_info
.
data_type
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_F16
);
int32_t
seq_len
=
_info
.
seqlen
;
int32_t
nhead
=
_info
.
nhead
;
int32_t
dhead
=
_info
.
dhead
;
int32_t
data_type
=
_info
.
data_type
;
int32_t
pos_type
=
_info
.
pos_type
;
int32_t
y_stride_seqlen
=
_info
.
y_stride_seqlen
;
int32_t
y_stride_nhead
=
_info
.
y_stride_nhead
;
int32_t
x_stride_seqlen
=
_info
.
x_stride_seqlen
;
int32_t
x_stride_nhead
=
_info
.
x_stride_nhead
;
return
rope_kernel_launch
(
y
,
(
void
*
)
x
,
(
void
*
)
pos_ids
,
(
void
*
)
sin_table
,
(
void
*
)
cos_table
,
seq_len
,
nhead
,
dhead
,
data_type
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
,
stream
);
return
rope_kernel_launch
(
y
,
(
void
*
)
x
,
(
void
*
)
pos_ids
,
(
void
*
)
sin_table
,
(
void
*
)
cos_table
,
seq_len
,
nhead
,
dhead
,
data_type
,
pos_type
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
,
stream
);
}
}
// namespace op::rope::ascend
src/infiniop/ops/rope/ascend/rope_kernel.cpp
View file @
2c0e9a6e
...
...
@@ -7,7 +7,7 @@ constexpr int32_t BYTE_ALIGN = 32;
using
namespace
AscendC
;
template
<
typename
T
>
template
<
typename
T
,
typename
U
>
class
RoPEKernel
{
public:
__aicore__
inline
RoPEKernel
()
{}
...
...
@@ -43,7 +43,7 @@ private:
TBuf
<
TPosition
::
VECCALC
>
tmpEvenBuf2
;
GlobalTensor
<
T
>
xGm
,
yGm
;
GlobalTensor
<
uint32_t
>
pGm
;
GlobalTensor
<
U
>
pGm
;
GlobalTensor
<
T
>
sinGm
;
GlobalTensor
<
T
>
cosGm
;
...
...
@@ -60,11 +60,11 @@ private:
int32_t
st_xnh_
;
};
template
<
typename
T
>
__aicore__
inline
void
RoPEKernel
<
T
>::
Init
(
GM_ADDR
y
,
GM_ADDR
x
,
GM_ADDR
pos
,
GM_ADDR
sin
,
GM_ADDR
cos
,
int32_t
dh
,
int32_t
st_ynt
,
int32_t
st_ynh
,
int32_t
st_xnt
,
int32_t
st_xnh
)
{
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
Init
(
GM_ADDR
y
,
GM_ADDR
x
,
GM_ADDR
pos
,
GM_ADDR
sin
,
GM_ADDR
cos
,
int32_t
dh
,
int32_t
st_ynt
,
int32_t
st_ynh
,
int32_t
st_xnt
,
int32_t
st_xnh
)
{
this
->
_tile_len
=
dh
;
this
->
st_ynt_
=
st_ynt
;
this
->
st_ynh_
=
st_ynh
;
...
...
@@ -77,7 +77,7 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
// Init global buffer
xGm
.
SetGlobalBuffer
((
__gm__
T
*
)
x
);
pGm
.
SetGlobalBuffer
(
reinterpret_cast
<
__gm__
uint32_t
*>
(
pos
)
)
;
pGm
.
SetGlobalBuffer
(
(
__gm__
U
*
)
pos
);
sinGm
.
SetGlobalBuffer
((
__gm__
T
*
)
sin
);
cosGm
.
SetGlobalBuffer
((
__gm__
T
*
)
cos
);
yGm
.
SetGlobalBuffer
((
__gm__
T
*
)
y
);
...
...
@@ -95,8 +95,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
pipe
.
InitBuffer
(
tmpEvenBuf2
,
_tile_len
/
2
*
sizeof
(
T
));
}
template
<
typename
T
>
__aicore__
inline
void
RoPEKernel
<
T
>::
CopyIn
(
int32_t
i
)
{
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
CopyIn
(
int32_t
i
)
{
LocalTensor
<
T
>
inputUb
=
inQue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
sinUb
=
sinQue
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
cosUb
=
cosQue
.
AllocTensor
<
T
>
();
...
...
@@ -114,8 +114,8 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
cosQue
.
EnQue
(
cosUb
);
}
template
<
typename
T
>
__aicore__
inline
void
RoPEKernel
<
T
>::
Compute
(
int32_t
i
)
{
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
Compute
(
int32_t
i
)
{
LocalTensor
<
T
>
inputUb
=
inQue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
sinUb
=
sinQue
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cosUb
=
cosQue
.
DeQue
<
T
>
();
...
...
@@ -166,8 +166,8 @@ __aicore__ inline void RoPEKernel<T>::Compute(int32_t i) {
cosQue
.
FreeTensor
(
cosUb
);
}
template
<
typename
T
>
__aicore__
inline
void
RoPEKernel
<
T
>::
CopyOut
(
int32_t
i
)
{
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
CopyOut
(
int32_t
i
)
{
LocalTensor
<
T
>
outputUb
=
outQue
.
DeQue
<
T
>
();
auto
idy
=
i
*
st_ynt_
+
_block_idx
*
st_ynh_
;
DataCopyExtParams
params
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
...
...
@@ -175,8 +175,8 @@ __aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) {
outQue
.
FreeTensor
(
outputUb
);
}
template
<
typename
T
>
__aicore__
inline
void
RoPEKernel
<
T
>::
Process
(
int32_t
nt
)
{
template
<
typename
T
,
typename
U
>
__aicore__
inline
void
RoPEKernel
<
T
,
U
>::
Process
(
int32_t
nt
)
{
for
(
int32_t
i
=
0
;
i
<
nt
;
++
i
)
{
CopyIn
(
i
);
...
...
@@ -185,22 +185,73 @@ __aicore__ inline void RoPEKernel<T>::Process(int32_t nt) {
}
}
// Emits a single `case` of the ROPE_KERNEL switch.
//   TYPE  - element type handed to RoPEKernel as its first template argument.
//   CODE  - integer label the case matches against.
//   POS_T - position-id element type handed to RoPEKernel as its second
//           template argument.
// The expansion refers to y, x, pos, sin, cos, dhead, seq_len and the four
// *_stride_* values by name, so all of them must be visible where the macro
// is expanded.
#define ROPE_KERNEL_CASE(TYPE, CODE, POS_T)                      \
    case CODE: {                                                 \
        RoPEKernel<TYPE, POS_T> op;                              \
        op.Init(y, x, pos, sin, cos, dhead,                      \
                y_stride_seqlen, y_stride_nhead,                 \
                x_stride_seqlen, x_stride_nhead);                \
        op.Process(seq_len);                                     \
        break;                                                   \
    }

// Chooses the RoPEKernel<TYPE, U> instantiation from the runtime value
// POSTYPE, then runs Init followed by Process on it.
// The labels are raw integer dtype codes mapped to fixed-width integer types:
//   3 -> int8_t,  4 -> int16_t,  5 -> int32_t,   6 -> int64_t,
//   7 -> uint8_t, 8 -> uint16_t, 9 -> uint32_t, 10 -> uint64_t.
// NOTE(review): presumably these codes mirror the project's dtype enum
// (the launcher in this file switches on 12 / 13 for float16 / float32);
// confirm against the enum definition.
// There is no default label: a POSTYPE outside 3..10 matches no case and the
// expansion does nothing.
// Expand as a statement without a trailing semicolon being required.
#define ROPE_KERNEL(TYPE, POSTYPE)                   \
    switch (POSTYPE) {                               \
        ROPE_KERNEL_CASE(TYPE, 3, int8_t)            \
        ROPE_KERNEL_CASE(TYPE, 4, int16_t)           \
        ROPE_KERNEL_CASE(TYPE, 5, int32_t)           \
        ROPE_KERNEL_CASE(TYPE, 6, int64_t)           \
        ROPE_KERNEL_CASE(TYPE, 7, uint8_t)           \
        ROPE_KERNEL_CASE(TYPE, 8, uint16_t)          \
        ROPE_KERNEL_CASE(TYPE, 9, uint32_t)          \
        ROPE_KERNEL_CASE(TYPE, 10, uint64_t)         \
    }
// Kernel entry point for RoPE on `half` data.
//
// y, x, pos, sin, cos      - global-memory addresses forwarded unchanged to
//                            RoPEKernel::Init through ROPE_KERNEL.
// seq_len                  - passed to RoPEKernel::Process.
// dhead                    - passed to RoPEKernel::Init.
// y_stride_seqlen/_nhead,
// x_stride_seqlen/_nhead   - stride values passed to RoPEKernel::Init.
// pos_type                 - integer code that ROPE_KERNEL switches on to pick
//                            the element type used for `pos`; a code that
//                            ROPE_KERNEL has no case for makes this kernel a
//                            no-op.
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y,
                                           GM_ADDR x,
                                           GM_ADDR pos,
                                           GM_ADDR sin,
                                           GM_ADDR cos,
                                           int32_t seq_len,
                                           int32_t dhead,
                                           int32_t y_stride_seqlen,
                                           int32_t y_stride_nhead,
                                           int32_t x_stride_seqlen,
                                           int32_t x_stride_nhead,
                                           int32_t pos_type) {
    ROPE_KERNEL(half, pos_type)
}
// Kernel entry point for RoPE on `float` data.
//
// y, x, pos, sin, cos      - global-memory addresses forwarded unchanged to
//                            RoPEKernel::Init through ROPE_KERNEL.
// seq_len                  - passed to RoPEKernel::Process.
// dhead                    - passed to RoPEKernel::Init.
// y_stride_seqlen/_nhead,
// x_stride_seqlen/_nhead   - stride values passed to RoPEKernel::Init.
// pos_type                 - integer code that ROPE_KERNEL switches on to pick
//                            the element type used for `pos`; a code that
//                            ROPE_KERNEL has no case for makes this kernel a
//                            no-op.
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y,
                                           GM_ADDR x,
                                           GM_ADDR pos,
                                           GM_ADDR sin,
                                           GM_ADDR cos,
                                           int32_t seq_len,
                                           int32_t dhead,
                                           int32_t y_stride_seqlen,
                                           int32_t y_stride_nhead,
                                           int32_t x_stride_seqlen,
                                           int32_t x_stride_nhead,
                                           int32_t pos_type) {
    ROPE_KERNEL(float, pos_type)
}
extern
"C"
infiniStatus_t
rope_kernel_launch
(
void
*
y
,
...
...
@@ -212,6 +263,7 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
int32_t
nhead
,
int32_t
dhead
,
int32_t
data_type
,
int32_t
pos_type
,
int32_t
y_stride_seqlen
,
int32_t
y_stride_nhead
,
int32_t
x_stride_seqlen
,
...
...
@@ -219,10 +271,10 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
void
*
stream
)
{
switch
(
data_type
)
{
case
12
:
// float16
rope_f16_kernel
<<<
nhead
,
nullptr
,
stream
>>>
(
y
,
x
,
pos
,
sin
,
cos
,
seq_len
,
dhead
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
);
rope_f16_kernel
<<<
nhead
,
nullptr
,
stream
>>>
(
y
,
x
,
pos
,
sin
,
cos
,
seq_len
,
dhead
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
,
pos_type
);
break
;
case
13
:
// float32
rope_f32_kernel
<<<
nhead
,
nullptr
,
stream
>>>
(
y
,
x
,
pos
,
sin
,
cos
,
seq_len
,
dhead
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
);
rope_f32_kernel
<<<
nhead
,
nullptr
,
stream
>>>
(
y
,
x
,
pos
,
sin
,
cos
,
seq_len
,
dhead
,
y_stride_seqlen
,
y_stride_nhead
,
x_stride_seqlen
,
x_stride_nhead
,
pos_type
);
break
;
default:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment