Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
1be004cb
Unverified
Commit
1be004cb
authored
May 27, 2025
by
PanZezhong1725
Committed by
GitHub
May 27, 2025
Browse files
Merge pull request #203 from InfiniTensor/ascend-rope
feat: 添加昇腾rope算子
parents
5beab8c0
8727bdcf
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
434 additions
and
67 deletions
+434
-67
src/infiniop/devices/ascend/CMakeLists.txt
src/infiniop/devices/ascend/CMakeLists.txt
+1
-2
src/infiniop/devices/ascend/ascend_kernel_common.h
src/infiniop/devices/ascend/ascend_kernel_common.h
+12
-3
src/infiniop/ops/rope/ascend/rope_ascend.cc
src/infiniop/ops/rope/ascend/rope_ascend.cc
+50
-0
src/infiniop/ops/rope/ascend/rope_ascend.h
src/infiniop/ops/rope/ascend/rope_ascend.h
+25
-0
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
+280
-0
src/infiniop/ops/rope/operator.cc
src/infiniop/ops/rope/operator.cc
+11
-20
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc
+0
-6
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h
+6
-0
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp
+46
-36
test/infiniop/rope.py
test/infiniop/rope.py
+3
-0
No files found.
src/infiniop/devices/ascend/CMakeLists.txt
View file @
1be004cb
...
@@ -23,10 +23,9 @@ include_directories(
...
@@ -23,10 +23,9 @@ include_directories(
${
CMAKE_SOURCE_DIR
}
/../../../../include/infiniop/
${
CMAKE_SOURCE_DIR
}
/../../../../include/infiniop/
)
)
ascendc_library
(
ascend_kernels STATIC
ascendc_library
(
ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
#
../../ops/ro
tary_embedding/ascend/rotary_embedding
_kernel.cpp
../../ops/ro
pe/ascend/rope_ascend
_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
)
)
src/infiniop/devices/ascend/ascend_kernel_common.h
View file @
1be004cb
...
@@ -4,8 +4,17 @@
...
@@ -4,8 +4,17 @@
#include "../../../../include/infinicore.h"
#include "../../../../include/infinicore.h"
#include "kernel_operator.h"
#include "kernel_operator.h"
constexpr
int32_t
BLOCK_NUM
=
8
;
constexpr
size_t
BLOCK_NUM
=
8
;
constexpr
int32_t
BUFFER_NUM
=
2
;
constexpr
size_t
BUFFER_NUM
=
2
;
constexpr
int32_t
BYTE_ALIGN
=
32
;
constexpr
size_t
BYTE_ALIGN
=
32
;
// Round a tile length up so that its footprint in bytes is a multiple of
// byte_align, returning the padded length in elements of T.
// Equivalent to the usual ceil-to-multiple idiom; already-aligned lengths
// are returned unchanged.
template <typename T>
__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
    const size_t raw_bytes = tile_len * sizeof(T);
    const size_t padded_bytes = ((raw_bytes + byte_align - 1) / byte_align) * byte_align;
    return padded_bytes / sizeof(T);
}
#endif
#endif
src/infiniop/ops/rope/ascend/rope_ascend.cc
0 → 100644
View file @
1be004cb
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {

Descriptor::~Descriptor() = default;

// Build an Ascend RoPE descriptor.
//
// Validates the five tensor descriptors through RoPEInfo::createRoPEInfo and
// stores the resulting shape/stride info in a new Descriptor. This kernel
// needs no device workspace, so workspace_size is 0.
//
// Returns INFINI_STATUS_SUCCESS on success; otherwise propagates the status
// produced by createRoPEInfo (via CHECK_RESULT).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    // Fixed typo: was "handle_ascned".
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
    auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(result);
    size_t workspace_size = 0;
    *desc_ptr = new Descriptor(
        std::move(result.take()),
        workspace_size,
        nullptr,
        handle_ascend->device,
        handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launch the Ascend RoPE kernel for the tensors described at create() time.
//
// Only F32/F16 data is supported (checked by CHECK_DTYPE). The workspace
// parameters are accepted for interface uniformity but unused here.
// Returns the status reported by rope_kernel_launch.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);

    auto data_type = _info.data_type;
    auto pos_type = _info.pos_type;
    auto seq_len = _info.seqlen;
    auto nhead = _info.nhead;
    auto dhead = _info.dhead;
    auto y_stride_seqlen = _info.y_stride_seqlen;
    auto y_stride_nhead = _info.y_stride_nhead;
    auto x_stride_seqlen = _info.x_stride_seqlen;
    auto x_stride_nhead = _info.x_stride_nhead;

    // rope_kernel_launch has a C interface taking non-const pointers; the
    // kernel only reads from these buffers, so dropping const is safe.
    // Named const_cast replaces the original C-style casts.
    return rope_kernel_launch(
        y,
        const_cast<void *>(x),
        const_cast<void *>(pos_ids),
        const_cast<void *>(sin_table),
        const_cast<void *>(cos_table),
        seq_len,
        nhead,
        dhead,
        data_type,
        pos_type,
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead,
        stream);
}

} // namespace op::rope::ascend
src/infiniop/ops/rope/ascend/rope_ascend.h
0 → 100644
View file @
1be004cb
// NOTE(review): the guard name __ACLNN_ROPE_H__ uses a leading double
// underscore, which is reserved to the implementation — consider renaming
// (project-wide convention permitting).
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__

#include "../rope.h"

// C-linkage host launcher implemented in rope_ascend_kernel.cpp.
// Dispatches the Ascend RoPE kernel for the given data type on `stream`.
//   y / x            : output / input tensors
//   pos              : position-id vector (integer type described by pos_type)
//   sin / cos        : precomputed sin/cos tables
//   seq_len, nhead, dhead : logical tensor extents
//   data_type        : element type of y/x/sin/cos (F16 or F32 supported)
//   *_stride_*       : element strides of y and x over (seqlen, nhead)
// Returns INFINI_STATUS_SUCCESS or INFINI_STATUS_BAD_TENSOR_DTYPE.
extern "C" infiniStatus_t rope_kernel_launch(
    void *y,
    void *x,
    void *pos,
    void *sin,
    void *cos,
    size_t seq_len,
    size_t nhead,
    size_t dhead,
    infiniDtype_t data_type,
    infiniDtype_t pos_type,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead,
    void *stream);

// Declares op::rope::ascend::Descriptor (see ../rope.h for the macro).
DESCRIPTOR(ascend)

#endif // __ACLNN_ROPE_H__
src/infiniop/ops/rope/ascend/rope_ascend_kernel.cpp
0 → 100644
View file @
1be004cb
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;

// Device-side rotary position embedding (RoPE) kernel.
//   T — element type of x/y/sin/cos (e.g. half or float)
//   U — integer type of the position-id vector
// Tensor shape is [nt, nh, dh]; the launch maps block_num = nh and each block
// processes tiles of length dh (one head slice per token position).
template <typename T, typename U>
class RoPEKernel {
public:
    __aicore__ inline RoPEKernel() {}
    // Init op
    // pos position vector
    // x input tensor
    // y output tensor
    // tensor shape [nt, nh, dh]
    // make block_num = nh, tile_len = dh
    __aicore__ inline void init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
                                size_t dh,
                                ptrdiff_t st_ynt, ptrdiff_t st_ynh,
                                ptrdiff_t st_xnt, ptrdiff_t st_xnh);
    // Run the copyIn -> compute -> copyOut pipeline for seq_len positions.
    __aicore__ inline void process(size_t seq_len);

private:
    // Copy a tile into UB
    __aicore__ inline void copyIn(size_t i);
    // Apply the rotation to the tile staged for position i.
    __aicore__ inline void compute(size_t i);
    // Write the rotated tile for position i back to global memory.
    __aicore__ inline void copyOut(size_t i);

private:
    TPipe pipe;
    // Double-buffered queues for the input tile and the sin/cos tiles.
    TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
    TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
    TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
    TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
    // Scratch buffers for the odd/even element lanes and their products.
    TBuf<TPosition::VECCALC> _tmp_odd_buf;
    TBuf<TPosition::VECCALC> _tmp_even_buf;
    TBuf<TPosition::VECCALC> _tmp_odd_buf1;
    TBuf<TPosition::VECCALC> _tmp_odd_buf2;
    TBuf<TPosition::VECCALC> _tmp_even_buf1;
    TBuf<TPosition::VECCALC> _tmp_even_buf2;
    GlobalTensor<T> _x_gm, _y_gm;
    GlobalTensor<U> _p_gm;
    GlobalTensor<T> _sin_gm;
    GlobalTensor<T> _cos_gm;
    size_t _block_idx;   // this core's block index (= head index)
    size_t _tile_len;    // elements per tile (= dh)
    size_t _copy_len;    // tile length padded to BYTE_ALIGN for DMA
    size_t _half_copy_len;
    // stridey[_st_ynt, _st_ynh, 1]
    ptrdiff_t _st_ynt;
    ptrdiff_t _st_ynh;
    // stridex[_st_xnt, _st_xnh, 1]
    ptrdiff_t _st_xnt;
    ptrdiff_t _st_xnh;
};
// Record shape/stride state, bind the global buffers, and size all local
// queues and scratch buffers. Must be called once per block before process().
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
                                              size_t dh,
                                              ptrdiff_t st_ynt, ptrdiff_t st_ynh,
                                              ptrdiff_t st_xnt, ptrdiff_t st_xnh) {
    this->_tile_len = dh;
    this->_st_ynt = st_ynt;
    this->_st_ynh = st_ynh;
    this->_st_xnt = st_xnt;
    this->_st_xnh = st_xnh;
    _copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
    // NOTE(review): _half_copy_len is computed from the full dh, making it
    // identical to _copy_len, yet copyIn() uses it to copy sin/cos tiles that
    // are indexed in units of _tile_len / 2 — this looks like it should be
    // alignTileLen<T>(dh / 2, BYTE_ALIGN). Confirm against the sin/cos table
    // layout before changing.
    _half_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
    _block_idx = GetBlockIdx();
    // Init global buffer
    _x_gm.SetGlobalBuffer((__gm__ T *)x);
    _p_gm.SetGlobalBuffer((__gm__ U *)pos);
    _sin_gm.SetGlobalBuffer((__gm__ T *)sin);
    _cos_gm.SetGlobalBuffer((__gm__ T *)cos);
    _y_gm.SetGlobalBuffer((__gm__ T *)y);
    // Init Queue buffer
    pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
    pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
    pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    // Six half-tile scratch buffers: odd/even element lanes plus the two
    // products needed for each lane before the final Sub/Add.
    pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
// Stage position i: DMA the input tile for (token i, head _block_idx) and the
// sin/cos tiles selected by pos[i] from global memory into UB, then enqueue
// all three for compute().
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
    LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
    LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
    LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
    // Get idx of current tile in total input
    auto idx = i * _st_xnt + _block_idx * _st_xnh;
    // Copy tile current tile into UB
    DataCopy(input_ub, _x_gm[idx], _copy_len);
    // Copy sin cos tile
    // pos_idx selects the row of the sin/cos tables; each row holds
    // _tile_len / 2 values (one per odd/even element pair).
    auto pos_idx = _p_gm(i);
    DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
    DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
    // Push in operands
    _in_que.EnQue(input_ub);
    _sin_que.EnQue(sin_ub);
    _cos_que.EnQue(cos_ub);
}
// Rotate the staged tile for position i:
//   y_odd  = x_odd * cos - x_even * sin
//   y_even = x_odd * sin + x_even * cos
// The tile is first de-interleaved into odd/even lanes with GatherMask,
// rotated with vector Mul/Sub/Add, then re-interleaved scalar-wise into the
// output tile. PipeBarrier<PIPE_V> separates dependent vector stages.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
    LocalTensor<T> input_ub = _in_que.DeQue<T>();
    LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
    LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
    LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
    LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
    LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
    LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
    LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
    LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
    LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
    // separate odd and even bit elements
    uint64_t rsvdCnt = 0;
    // repeatTimes covers the whole tile in 256-byte chunks.
    GatherMaskParams gMaskParams = {
        1,
        static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // no more than 256(<=255)
        8,
        8,
    };
    // Pattern 1 gathers elements at even offsets (the "odd lane" here),
    // pattern 2 the interleaved partner elements.
    GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
    GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
    PipeBarrier<PIPE_V>();
    // compute odd bit elements
    // y_odd = x_odd * cos - x_even * sin
    Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
    Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
    PipeBarrier<PIPE_V>();
    Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
    // compute even bit elements
    // y_even = x_odd * sin + x_even * cos
    Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
    Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
    PipeBarrier<PIPE_V>();
    Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
    // combine odd and even bit elements
    for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
        output_ub(j * 2) = tmp_odd1(j);
        output_ub(j * 2 + 1) = tmp_even1(j);
    }
    _out_que.EnQue<T>(output_ub);
    _in_que.FreeTensor(input_ub);
    _sin_que.FreeTensor(sin_ub);
    _cos_que.FreeTensor(cos_ub);
}
// Write the rotated tile for position i back to y at
// (token i, head _block_idx). DataCopyPad is used so the exact
// _tile_len * sizeof(T) bytes are written even when the tile is not
// BYTE_ALIGN-aligned.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
    LocalTensor<T> output_ub = _out_que.DeQue<T>();
    auto idy = i * _st_ynt + _block_idx * _st_ynh;
    // {blockCount=1, blockLen in bytes, no strides/padding}
    DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
    DataCopyPad(_y_gm[idy], output_ub, params);
    _out_que.FreeTensor(output_ub);
}
// Drive the full pipeline: for every token position handled by this block,
// stage the tile, rotate it, and write it back.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
    size_t token = 0;
    while (token < seq_len) {
        copyIn(token);
        compute(token);
        copyOut(token);
        ++token;
    }
}
// Argument pack forwarded to RoPEKernel::init by every instantiation below.
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
                              y_stride_seqlen, y_stride_nhead, \
                              x_stride_seqlen, x_stride_nhead
// One switch case: instantiate RoPEKernel<TYPE, POS_T> and run it.
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
    case POS_TYPE_ENUM: { \
        RoPEKernel<TYPE, POS_T> op; \
        op.init(ROPE_KERNEL_INIT_ARGS); \
        op.process(seq_len); \
        break; \
    }
// Dispatch on the runtime position-id dtype; unknown dtypes are a no-op.
#define ROPE_KERNEL(TYPE, POSTYPE) \
    switch (POSTYPE) { \
        CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
        CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
        CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
        CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
        CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
        CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
        CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
        CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
    default: \
        break; \
    }
// Define one __global__ kernel entry per element type; pos_type is passed as
// int32_t and dispatched at runtime inside the kernel.
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
    __global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
                                           GM_ADDR x, \
                                           GM_ADDR pos, \
                                           GM_ADDR sin, \
                                           GM_ADDR cos, \
                                           size_t seq_len, \
                                           size_t dhead, \
                                           ptrdiff_t y_stride_seqlen, \
                                           ptrdiff_t y_stride_nhead, \
                                           ptrdiff_t x_stride_seqlen, \
                                           ptrdiff_t x_stride_nhead, \
                                           int32_t pos_type) { \
        ROPE_KERNEL(TYPE, pos_type) \
    }
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
// Scrub the helper macros so they cannot leak into other TUs/sections.
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
// Host-side C-linkage dispatcher: selects the kernel instantiation matching
// `dtype` and launches it with nhead blocks (one head per core) on `stream`.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported element types.
extern "C" infiniStatus_t rope_kernel_launch(
    void *y,
    void *x,
    void *pos,
    void *sin,
    void *cos,
    size_t seq_len,
    size_t nhead,
    size_t dhead,
    infiniDtype_t dtype,
    infiniDtype_t pos_type,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead,
    void *stream) {
// One switch case: launch KERNEL_NAME and return success immediately
// (hence no `break` needed).
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
    case DTYPE_ENUM: \
        KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
                                                seq_len, \
                                                dhead, \
                                                y_stride_seqlen, \
                                                y_stride_nhead, \
                                                x_stride_seqlen, \
                                                x_stride_nhead, \
                                                pos_type); \
        return INFINI_STATUS_SUCCESS;
    switch (dtype) {
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
src/infiniop/ops/rope/operator.cc
View file @
1be004cb
...
@@ -8,6 +8,9 @@
...
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh"
#include "cuda/rope_cuda.cuh"
#endif
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_ascend.h"
#endif
__C
infiniStatus_t
infiniopCreateRoPEDescriptor
(
__C
infiniStatus_t
infiniopCreateRoPEDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
...
@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
pos_ids
,
sin_table
,
cos_table
);
pos_ids
,
sin_table
,
cos_table
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendCreateRoPEDescriptor
((
AscendHandle_t
)
handle
,
(
RoPEAscendDescriptor_t
*
)
desc_ptr
,
t
,
pos_ids
,
sin_table
,
cos_table
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
...
@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return
bangGetRoPEWorkspaceSize
((
RoPEBangDescriptor_t
)
desc
,
size
);
return
bangGetRoPEWorkspaceSize
((
RoPEBangDescriptor_t
)
desc
,
size
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendGetRoPEWorkspaceSize
((
RoPEAscendDescriptor_t
)
desc
,
size
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(
...
@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(
t
,
pos_ids
,
sin_table
,
cos_table
,
stream
);
t
,
pos_ids
,
sin_table
,
cos_table
,
stream
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendRoPE
((
RoPEAscendDescriptor_t
)
desc
,
workspace
,
workspace_size
,
t
,
pos_ids
,
sin_table
,
cos_table
,
stream
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
...
@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
return
bangDestroyRoPEDescriptor
((
RoPEBangDescriptor_t
)
desc
);
return
bangDestroyRoPEDescriptor
((
RoPEBangDescriptor_t
)
desc
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendDestroyRoPEDescriptor
((
RoPEAscendDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
...
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc
View file @
1be004cb
...
@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
...
@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
extern
"C"
infiniStatus_t
swiglu_kernel_launch
(
void
*
c
,
void
*
a
,
void
*
b
,
infiniDtype_t
dtype
,
size_t
batch
,
size_t
seq
,
size_t
hd
,
ptrdiff_t
stride_batch_c
,
ptrdiff_t
stride_batch_a
,
ptrdiff_t
stride_batch_b
,
ptrdiff_t
stride_seq_c
,
ptrdiff_t
stride_seq_a
,
ptrdiff_t
stride_seq_b
,
void
*
stream
);
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
size_t
workspace_size
,
void
*
c
,
void
*
c
,
...
...
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h
View file @
1be004cb
...
@@ -69,5 +69,11 @@ public:
...
@@ -69,5 +69,11 @@ public:
void
*
stream
)
const
;
void
*
stream
)
const
;
};
};
extern
"C"
infiniStatus_t
swiglu_kernel_launch
(
void
*
c
,
void
*
a
,
void
*
b
,
infiniDtype_t
dtype
,
size_t
batch
,
size_t
seq
,
size_t
hd
,
ptrdiff_t
stride_batch_c
,
ptrdiff_t
stride_batch_a
,
ptrdiff_t
stride_batch_b
,
ptrdiff_t
stride_seq_c
,
ptrdiff_t
stride_seq_a
,
ptrdiff_t
stride_seq_b
,
void
*
stream
);
}
// namespace op::swiglu::ascend
}
// namespace op::swiglu::ascend
#endif // __ACLNN_SWIGLU_H__
#endif // __ACLNN_SWIGLU_H__
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp
View file @
1be004cb
...
@@ -6,15 +6,20 @@ template <typename T>
...
@@ -6,15 +6,20 @@ template <typename T>
class
SwigluKernel
{
class
SwigluKernel
{
public:
public:
__aicore__
inline
SwigluKernel
()
{}
__aicore__
inline
SwigluKernel
()
{}
__aicore__
inline
void
init
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
__aicore__
inline
void
init
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
size_t
batch_
,
size_t
seq
,
size_t
hd
,
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
);
ptrdiff_t
stride_batch_c
,
ptrdiff_t
stride_batch_a
,
ptrdiff_t
stride_batch_b
,
ptrdiff_t
stride_seq_c
,
ptrdiff_t
stride_seq_a
,
ptrdiff_t
stride_seq_b
);
__aicore__
inline
void
process
();
__aicore__
inline
void
process
();
private:
private:
__aicore__
inline
void
copyIn
(
int64
_t
i
);
__aicore__
inline
void
copyIn
(
size
_t
i
);
__aicore__
inline
void
compute
(
int64
_t
i
);
__aicore__
inline
void
compute
(
size
_t
i
);
__aicore__
inline
void
copyOut
(
int64
_t
i
);
__aicore__
inline
void
copyOut
(
size
_t
i
);
private:
private:
GlobalTensor
<
T
>
_c_gm
,
_a_gm
,
_b_gm
;
GlobalTensor
<
T
>
_c_gm
,
_a_gm
,
_b_gm
;
...
@@ -23,16 +28,21 @@ private:
...
@@ -23,16 +28,21 @@ private:
TPipe
_pipe
;
TPipe
_pipe
;
float
_beta_value
=
1.0
f
;
float
_beta_value
=
1.0
f
;
int64
_t
_block_idx
,
_tile_len
,
_copy_len
,
size
_t
_block_idx
,
_tile_len
,
_copy_len
,
_batch
,
_seq_len
,
_hidden_size
,
_batch
,
_seq_len
,
_hidden_size
,
_stride_seq_a
,
_stride_seq_b
,
_stride_seq_c
;
_stride_seq_a
,
_stride_seq_b
,
_stride_seq_c
;
int64_t
_stride_batch_a
=
1
,
_stride_batch_b
=
1
,
_stride_batch_c
=
1
;
int64_t
_stride_batch_a
=
1
,
_stride_batch_b
=
1
,
_stride_batch_c
=
1
;
};
};
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
init
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
__aicore__
inline
void
SwigluKernel
<
T
>::
init
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
size_t
batch_
,
size_t
seq
,
size_t
hd
,
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
)
{
ptrdiff_t
stride_batch_c
,
ptrdiff_t
stride_batch_a
,
ptrdiff_t
stride_batch_b
,
ptrdiff_t
stride_seq_c
,
ptrdiff_t
stride_seq_a
,
ptrdiff_t
stride_seq_b
)
{
// Init Shape & StrideVariables
// Init Shape & StrideVariables
_batch
=
batch_
;
_batch
=
batch_
;
_seq_len
=
seq
;
_seq_len
=
seq
;
...
@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
...
@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
_block_idx
=
GetBlockIdx
();
_block_idx
=
GetBlockIdx
();
_tile_len
=
_block_idx
<
(
_hidden_size
%
BLOCK_NUM
)
?
(
_hidden_size
/
BLOCK_NUM
)
+
1
:
(
_hidden_size
/
BLOCK_NUM
);
_tile_len
=
_block_idx
<
(
_hidden_size
%
BLOCK_NUM
)
?
(
_hidden_size
/
BLOCK_NUM
)
+
1
:
(
_hidden_size
/
BLOCK_NUM
);
_copy_len
=
(
_tile_len
*
sizeof
(
T
))
%
BYTE_ALIGN
==
0
?
_tile_len
:
(
_tile_len
*
sizeof
(
T
)
+
(
BYTE_ALIGN
-
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
))
/
sizeof
(
T
);
_copy_len
=
alignTileLen
<
T
>
(
_tile_len
,
BYTE_ALIGN
);
// Set global tensor
// Set global tensor
_a_gm
.
SetGlobalBuffer
((
__gm__
T
*
)
a
);
_a_gm
.
SetGlobalBuffer
((
__gm__
T
*
)
a
);
...
@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
...
@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
copyIn
(
int64
_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
copyIn
(
size
_t
i
)
{
// Alloc tensor from queue memory
// Alloc tensor from queue memory
LocalTensor
<
T
>
aLocal
=
_in_queue_a
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
aLocal
=
_in_queue_a
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_in_queue_b
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_in_queue_b
.
AllocTensor
<
T
>
();
...
@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
...
@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
auto
batch_idx
=
_batch
==
1
?
0
:
i
/
_seq_len
;
auto
batch_idx
=
_batch
==
1
?
0
:
i
/
_seq_len
;
auto
seq_idx
=
_batch
==
1
?
i
:
i
%
_seq_len
;
auto
seq_idx
=
_batch
==
1
?
i
:
i
%
_seq_len
;
int64
_t
idxa
=
batch_idx
*
_stride_batch_a
+
seq_idx
*
_stride_seq_a
+
_block_idx
*
_tile_len
;
ptrdiff
_t
idxa
=
batch_idx
*
_stride_batch_a
+
seq_idx
*
_stride_seq_a
+
_block_idx
*
_tile_len
;
int64
_t
idxb
=
batch_idx
*
_stride_batch_b
+
seq_idx
*
_stride_seq_b
+
_block_idx
*
_tile_len
;
ptrdiff
_t
idxb
=
batch_idx
*
_stride_batch_b
+
seq_idx
*
_stride_seq_b
+
_block_idx
*
_tile_len
;
// Copy process_th tile from global tensor to local tensor
// Copy process_th tile from global tensor to local tensor
DataCopy
(
aLocal
,
_a_gm
[
idxa
],
_copy_len
);
DataCopy
(
aLocal
,
_a_gm
[
idxa
],
_copy_len
);
DataCopy
(
bLocal
,
_b_gm
[
idxb
],
_copy_len
);
DataCopy
(
bLocal
,
_b_gm
[
idxb
],
_copy_len
);
...
@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
...
@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
compute
(
int64
_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
compute
(
size
_t
i
)
{
// Deque input tensors from VECIN queue
// Deque input tensors from VECIN queue
LocalTensor
<
T
>
aLocal
=
_in_queue_a
.
DeQue
<
T
>
();
LocalTensor
<
T
>
aLocal
=
_in_queue_a
.
DeQue
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_in_queue_b
.
DeQue
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_in_queue_b
.
DeQue
<
T
>
();
...
@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
...
@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
copyOut
(
int64
_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
copyOut
(
size
_t
i
)
{
// Deque output tensor from VECOUT queue
// Deque output tensor from VECOUT queue
LocalTensor
<
T
>
cLocal
=
_out_queue_c
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cLocal
=
_out_queue_c
.
DeQue
<
T
>
();
auto
batch_idx
=
_batch
==
1
?
0
:
i
/
_seq_len
;
auto
batch_idx
=
_batch
==
1
?
0
:
i
/
_seq_len
;
auto
seq_idx
=
_batch
==
1
?
i
:
i
%
_seq_len
;
auto
seq_idx
=
_batch
==
1
?
i
:
i
%
_seq_len
;
int64
_t
idxc
=
batch_idx
*
_stride_batch_c
+
seq_idx
*
_stride_seq_c
+
_block_idx
*
_tile_len
;
ptrdiff
_t
idxc
=
batch_idx
*
_stride_batch_c
+
seq_idx
*
_stride_seq_c
+
_block_idx
*
_tile_len
;
// Copy progress_th tile from local tensor to global tensor
// Copy progress_th tile from local tensor to global tensor
if
(
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
!=
0
)
{
if
(
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
!=
0
)
{
DataCopyExtParams
dcep
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
DataCopyExtParams
dcep
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
...
@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
...
@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
process
()
{
__aicore__
inline
void
SwigluKernel
<
T
>::
process
()
{
for
(
int64
_t
i
=
0
;
i
<
_batch
*
_seq_len
;
++
i
)
{
for
(
size
_t
i
=
0
;
i
<
_batch
*
_seq_len
;
++
i
)
{
copyIn
(
i
);
copyIn
(
i
);
compute
(
i
);
compute
(
i
);
copyOut
(
i
);
copyOut
(
i
);
}
}
}
}
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE)
\
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b,
\
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
int64
_t batch,
int64
_t seq,
int64
_t hd, \
size
_t batch,
size
_t seq,
size
_t hd, \
int64
_t stride_batch_c,
\
ptrdiff
_t stride_batch_c, \
int64
_t stride_batch_a,
\
ptrdiff
_t stride_batch_a, \
int64
_t stride_batch_b,
\
ptrdiff
_t stride_batch_b, \
int64
_t stride_seq_c,
\
ptrdiff
_t stride_seq_c, \
int64
_t stride_seq_a,
\
ptrdiff
_t stride_seq_a, \
int64
_t stride_seq_b) {
\
ptrdiff
_t stride_seq_b) { \
SwigluKernel<TYPE> op;
\
SwigluKernel<TYPE> op; \
op.init(c, a, b,
\
op.init(c, a, b, \
batch, seq, hd,
\
batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b,
\
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b);
\
stride_seq_c, stride_seq_a, stride_seq_b); \
op.process();
\
op.process(); \
}
}
DEFINE_SWIGLU_KERNEL
(
swiglu_kernel_half
,
half
)
DEFINE_SWIGLU_KERNEL
(
swiglu_kernel_half
,
half
)
...
@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
...
@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
case DTYPE_ENUM: \
case DTYPE_ENUM: \
KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \
KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \
c, a, b, \
c, a, b, \
static_cast<int64_t>(batch),
\
batch,
\
s
tatic_cast<int64_t>(seq),
\
s
eq,
\
static_cast<int64_t>(hd),
\
hd,
\
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
stride_seq_c, stride_seq_a, stride_seq_b); \
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
...
...
test/infiniop/rope.py
View file @
1be004cb
...
@@ -189,6 +189,9 @@ def test(
...
@@ -189,6 +189,9 @@ def test(
)
)
lib_rope
()
lib_rope
()
if
sync
is
not
None
:
sync
()
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
if
DEBUG
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment