Commit 00842e65 authored by zhangyunze's avatar zhangyunze Committed by zhangyue
Browse files

feat: 添加昇腾rope算子

parent 0c803397
...@@ -23,10 +23,9 @@ include_directories( ...@@ -23,10 +23,9 @@ include_directories(
${CMAKE_SOURCE_DIR}/../../../../include/infiniop/ ${CMAKE_SOURCE_DIR}/../../../../include/infiniop/
) )
ascendc_library(ascend_kernels STATIC ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp ../../ops/rope/ascend/rope_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp # ../../ops/random_sample/ascend/random_sample_kernel.cpp
) )
#include "rope_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
Descriptor::~Descriptor()
= default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(result);
size_t workspace_size = 0;
*desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id);
return INFINI_STATUS_SUCCESS;
}
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
int32_t seq_len,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
int32_t x_stride_nhead,
void *stream);
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
// TODO: 是否强加这个判断
std::cout << "pos_type: " << _info.pos_type << std::endl;
CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32);
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen;
int32_t nhead = _info.nhead;
int32_t dhead = _info.dhead;
int32_t data_type = _info.data_type;
int32_t y_stride_seqlen = _info.y_stride_seqlen;
int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead;
std::cout << "shape is " << seq_len << ", " << nhead << ", " << dhead << std::endl;
std::cout << "y_stride is " << y_stride_seqlen << ", " << y_stride_nhead << ", 1" << std::endl;
std::cout << "x_stride is " << x_stride_seqlen << ", " << x_stride_nhead << ", 1" << std::endl;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../../../include/infinicore.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
using namespace AscendC;
template <typename T>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
// Init op
// pos position vector
// x input tensor
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh);
__aicore__ inline void Process(int32_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void CopyIn(int32_t i);
__aicore__ inline void Compute(int32_t i);
__aicore__ inline void CopyOut(int32_t i);
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> inQue;
TQue<QuePosition::VECIN, BUFFER_NUM> sinQue;
TQue<QuePosition::VECIN, BUFFER_NUM> cosQue;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQue;
TBuf<TPosition::VECCALC> tmpOddBuf;
TBuf<TPosition::VECCALC> tmpEvenBuf;
TBuf<TPosition::VECCALC> tmpOddBuf1;
TBuf<TPosition::VECCALC> tmpOddBuf2;
TBuf<TPosition::VECCALC> tmpEvenBuf1;
TBuf<TPosition::VECCALC> tmpEvenBuf2;
GlobalTensor<T> xGm, yGm;
GlobalTensor<uint32_t> pGm;
GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm;
uint32_t _block_idx;
uint32_t _tile_len;
// stridey[st_ynt_, st_ynh_, 1]
int32_t st_ynt_;
int32_t st_ynh_;
// stridex[st_xnt_, st_xnh_, 1]
int32_t st_xnt_;
int32_t st_xnh_;
};
template <typename T>
__aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh) {
this->_tile_len = dh;
this->st_ynt_ = st_ynt;
this->st_ynh_ = st_ynh;
this->st_xnt_ = st_xnt;
this->st_xnh_ = st_xnh;
_block_idx = GetBlockIdx();
// Init global buffer
xGm.SetGlobalBuffer((__gm__ T *)x);
pGm.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t *>(pos));
sinGm.SetGlobalBuffer((__gm__ T *)sin);
cosGm.SetGlobalBuffer((__gm__ T *)cos);
yGm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(inQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(outQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(sinQue, BUFFER_NUM, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(cosQue, BUFFER_NUM, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf2, _tile_len / 2 * sizeof(T));
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
// Get idx of current tile in total input
auto idx = i * st_xnt_ + _block_idx * st_xnh_;
// Copy tile current tile into UB
DataCopy(inputUb, xGm[idx], _tile_len);
// Copy sin cos tile
auto pos_idx = pGm(i);
DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _tile_len / 2);
DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _tile_len / 2);
// Push in operands
inQue.EnQue(inputUb);
sinQue.EnQue(sinUb);
cosQue.EnQue(cosUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::Compute(int32_t i) {
LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>();
LocalTensor<T> outputUb = outQue.AllocTensor<T>();
LocalTensor<T> tmpOdd = tmpOddBuf.Get<T>();
LocalTensor<T> tmpEven = tmpEvenBuf.Get<T>();
LocalTensor<T> tmpOdd1 = tmpOddBuf1.Get<T>();
LocalTensor<T> tmpOdd2 = tmpOddBuf2.Get<T>();
LocalTensor<T> tmpEven1 = tmpEvenBuf1.Get<T>();
LocalTensor<T> tmpEven2 = tmpEvenBuf2.Get<T>();
// separate odd and even bit elements
uint64_t rsvdCnt = 0;
GatherMaskParams gMaskParams = {
1,
static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // no more than 256(<=255)
8,
8,
};
GatherMask<T>(tmpOdd, inputUb, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmpEven, inputUb, 2, false, 0, gMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmpOdd1, tmpOdd, cosUb, _tile_len / 2);
Mul<T>(tmpOdd2, tmpEven, sinUb, _tile_len / 2);
PipeBarrier<PIPE_V>();
Sub<T>(tmpOdd1, tmpOdd1, tmpOdd2, _tile_len / 2);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmpEven1, tmpOdd, sinUb, _tile_len / 2);
Mul<T>(tmpEven2, tmpEven, cosUb, _tile_len / 2);
PipeBarrier<PIPE_V>();
Add<T>(tmpEven1, tmpEven1, tmpEven2, _tile_len / 2);
// combine odd and even bit elements
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
outputUb(j * 2) = tmpOdd1(j);
outputUb(j * 2 + 1) = tmpEven1(j);
}
outQue.EnQue<T>(outputUb);
inQue.FreeTensor(inputUb);
sinQue.FreeTensor(sinUb);
cosQue.FreeTensor(cosUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>();
auto idy = i * st_ynt_ + _block_idx * st_ynh_;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm[idy], outputUb, params);
outQue.FreeTensor(outputUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::Process(int32_t nt) {
for (int32_t i = 0; i < nt; ++i) {
CopyIn(i);
Compute(i);
CopyOut(i);
}
}
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) {
RoPEKernel<half> op;
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
op.Process(seq_len);
}
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) {
RoPEKernel<float> op;
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
op.Process(seq_len);
}
extern "C" infiniStatus_t rope_kernel_launch(void *y,
void *x,
void *pos,
void *sin,
void *cos,
int32_t seq_len,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
int32_t x_stride_nhead,
void *stream) {
switch (data_type) {
case 12: // float16
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
break;
case 13: // float32
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
\ No newline at end of file
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh" #include "cuda/rope_cuda.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_aclnn.h"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
pos_ids, sin_table, cos_table); pos_ids, sin_table, cos_table);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { CREATE(INFINI_DEVICE_ASCEND, ascend);
return ascendCreateRoPEDescriptor((AscendHandle_t)handle,
(RoPEAscendDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size); return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { GET(INFINI_DEVICE_ASCEND, ascend);
return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t)desc, size);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE( ...@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(
t, pos_ids, sin_table, cos_table, stream); t, pos_ids, sin_table, cos_table, stream);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { CALCULATE(INFINI_DEVICE_ASCEND, ascend);
return ascendRoPE((RoPEAscendDescriptor_t)desc, workspace,
workspace_size, t, pos_ids, sin_table, cos_table,
stream);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc); return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { DELETE(INFINI_DEVICE_ASCEND, ascend);
return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t)desc);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment