Commit 2c0e9a6e authored by zhangyunze's avatar zhangyunze Committed by zhangyue
Browse files

fix: 解除 kernel 中对 pos_type 的限制 (remove the kernel's restriction on the position-id dtype)

parent bbb0105b
...@@ -33,6 +33,7 @@ extern "C" infiniStatus_t rope_kernel_launch( ...@@ -33,6 +33,7 @@ extern "C" infiniStatus_t rope_kernel_launch(
int32_t nhead, int32_t nhead,
int32_t dhead, int32_t dhead,
int32_t data_type, int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen, int32_t y_stride_seqlen,
int32_t y_stride_nhead, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_seqlen,
...@@ -48,17 +49,16 @@ infiniStatus_t Descriptor::calculate( ...@@ -48,17 +49,16 @@ infiniStatus_t Descriptor::calculate(
const void *sin_table, const void *sin_table,
const void *cos_table, const void *cos_table,
void *stream) const { void *stream) const {
// TODO: 是否有可能解除这个判断
CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32);
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16); CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen; int32_t seq_len = _info.seqlen;
int32_t nhead = _info.nhead; int32_t nhead = _info.nhead;
int32_t dhead = _info.dhead; int32_t dhead = _info.dhead;
int32_t data_type = _info.data_type; int32_t data_type = _info.data_type;
int32_t pos_type = _info.pos_type;
int32_t y_stride_seqlen = _info.y_stride_seqlen; int32_t y_stride_seqlen = _info.y_stride_seqlen;
int32_t y_stride_nhead = _info.y_stride_nhead; int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen; int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead; int32_t x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream); return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
} }
} // namespace op::rope::ascend } // namespace op::rope::ascend
...@@ -7,7 +7,7 @@ constexpr int32_t BYTE_ALIGN = 32; ...@@ -7,7 +7,7 @@ constexpr int32_t BYTE_ALIGN = 32;
using namespace AscendC; using namespace AscendC;
template <typename T> template <typename T, typename U>
class RoPEKernel { class RoPEKernel {
public: public:
__aicore__ inline RoPEKernel() {} __aicore__ inline RoPEKernel() {}
...@@ -43,7 +43,7 @@ private: ...@@ -43,7 +43,7 @@ private:
TBuf<TPosition::VECCALC> tmpEvenBuf2; TBuf<TPosition::VECCALC> tmpEvenBuf2;
GlobalTensor<T> xGm, yGm; GlobalTensor<T> xGm, yGm;
GlobalTensor<uint32_t> pGm; GlobalTensor<U> pGm;
GlobalTensor<T> sinGm; GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm; GlobalTensor<T> cosGm;
...@@ -60,8 +60,8 @@ private: ...@@ -60,8 +60,8 @@ private:
int32_t st_xnh_; int32_t st_xnh_;
}; };
template <typename T> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, __aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh, int32_t dh,
int32_t st_ynt, int32_t st_ynh, int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh) { int32_t st_xnt, int32_t st_xnh) {
...@@ -77,7 +77,7 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM ...@@ -77,7 +77,7 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
// Init global buffer // Init global buffer
xGm.SetGlobalBuffer((__gm__ T *)x); xGm.SetGlobalBuffer((__gm__ T *)x);
pGm.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t *>(pos)); pGm.SetGlobalBuffer((__gm__ U *)pos);
sinGm.SetGlobalBuffer((__gm__ T *)sin); sinGm.SetGlobalBuffer((__gm__ T *)sin);
cosGm.SetGlobalBuffer((__gm__ T *)cos); cosGm.SetGlobalBuffer((__gm__ T *)cos);
yGm.SetGlobalBuffer((__gm__ T *)y); yGm.SetGlobalBuffer((__gm__ T *)y);
...@@ -95,8 +95,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM ...@@ -95,8 +95,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
pipe.InitBuffer(tmpEvenBuf2, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(tmpEvenBuf2, _tile_len / 2 * sizeof(T));
} }
template <typename T> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) {
LocalTensor<T> inputUb = inQue.AllocTensor<T>(); LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>(); LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>(); LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
...@@ -114,8 +114,8 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) { ...@@ -114,8 +114,8 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
cosQue.EnQue(cosUb); cosQue.EnQue(cosUb);
} }
template <typename T> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T>::Compute(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) {
LocalTensor<T> inputUb = inQue.DeQue<T>(); LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>(); LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>(); LocalTensor<T> cosUb = cosQue.DeQue<T>();
...@@ -166,8 +166,8 @@ __aicore__ inline void RoPEKernel<T>::Compute(int32_t i) { ...@@ -166,8 +166,8 @@ __aicore__ inline void RoPEKernel<T>::Compute(int32_t i) {
cosQue.FreeTensor(cosUb); cosQue.FreeTensor(cosUb);
} }
template <typename T> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::CopyOut(int32_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>(); LocalTensor<T> outputUb = outQue.DeQue<T>();
auto idy = i * st_ynt_ + _block_idx * st_ynh_; auto idy = i * st_ynt_ + _block_idx * st_ynh_;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
...@@ -175,8 +175,8 @@ __aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) { ...@@ -175,8 +175,8 @@ __aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) {
outQue.FreeTensor(outputUb); outQue.FreeTensor(outputUb);
} }
template <typename T> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T>::Process(int32_t nt) { __aicore__ inline void RoPEKernel<T, U>::Process(int32_t nt) {
for (int32_t i = 0; i < nt; ++i) { for (int32_t i = 0; i < nt; ++i) {
CopyIn(i); CopyIn(i);
...@@ -185,22 +185,73 @@ __aicore__ inline void RoPEKernel<T>::Process(int32_t nt) { ...@@ -185,22 +185,73 @@ __aicore__ inline void RoPEKernel<T>::Process(int32_t nt) {
} }
} }
// ROPE_KERNEL_CASE(TYPE, POS_T, CODE)
//
// Expands to a single `case CODE:` of the position-dtype switch below. The
// case instantiates RoPEKernel<TYPE, POS_T>, calls Init() with the tensor
// addresses and strides, then calls Process(seq_len).
//
// The expansion refers to y, x, pos, sin, cos, seq_len, dhead,
// y_stride_seqlen, y_stride_nhead, x_stride_seqlen and x_stride_nhead by
// name, so all of them must be in scope where the macro is expanded.
#define ROPE_KERNEL_CASE(TYPE, POS_T, CODE) \
    case CODE: { \
        RoPEKernel<TYPE, POS_T> op; \
        op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
        op.Process(seq_len); \
        break; \
    }

// ROPE_KERNEL(TYPE, POSTYPE)
//
// Selects the RoPEKernel instantiation for data element type TYPE and the
// runtime position-id dtype code POSTYPE, then runs it.
//
// NOTE(review): the codes 3..10 are presumably the numeric values of the
// project's dtype enum for I8, I16, I32, I64, U8, U16, U32, U64 (the host
// side previously required INFINI_DTYPE_U32, handled here as 9) -- confirm
// against the enum definition.
//
// A POSTYPE outside 3..10 matches no case and the macro does nothing, so no
// output is written for it.
//
// The macro expands to a bare `switch` statement (not `do { } while (0)`)
// because existing call sites invoke it without a trailing semicolon.
#define ROPE_KERNEL(TYPE, POSTYPE) \
    switch (POSTYPE) { \
        ROPE_KERNEL_CASE(TYPE, int8_t, 3) \
        ROPE_KERNEL_CASE(TYPE, int16_t, 4) \
        ROPE_KERNEL_CASE(TYPE, int32_t, 5) \
        ROPE_KERNEL_CASE(TYPE, int64_t, 6) \
        ROPE_KERNEL_CASE(TYPE, uint8_t, 7) \
        ROPE_KERNEL_CASE(TYPE, uint16_t, 8) \
        ROPE_KERNEL_CASE(TYPE, uint32_t, 9) \
        ROPE_KERNEL_CASE(TYPE, uint64_t, 10) \
    default: \
        break; \
    }
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, __global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead, int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) { int32_t x_stride_seqlen, int32_t x_stride_nhead,
RoPEKernel<half> op; int32_t pos_type){
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); ROPE_KERNEL(half, pos_type)
op.Process(seq_len);
} }
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, __global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead, int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) { int32_t x_stride_seqlen, int32_t x_stride_nhead,
RoPEKernel<float> op; int32_t pos_type) {
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); ROPE_KERNEL(float, pos_type)
op.Process(seq_len);
} }
extern "C" infiniStatus_t rope_kernel_launch(void *y, extern "C" infiniStatus_t rope_kernel_launch(void *y,
...@@ -212,6 +263,7 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y, ...@@ -212,6 +263,7 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
int32_t nhead, int32_t nhead,
int32_t dhead, int32_t dhead,
int32_t data_type, int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen, int32_t y_stride_seqlen,
int32_t y_stride_nhead, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_seqlen,
...@@ -219,10 +271,10 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y, ...@@ -219,10 +271,10 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
void *stream) { void *stream) {
switch (data_type) { switch (data_type) {
case 12: // float16 case 12: // float16
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break; break;
case 13: // float32 case 13: // float32
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break; break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment