Commit 2c0e9a6e authored by zhangyunze, committed by zhangyue

fix: remove the pos_type restriction in the kernel

parent bbb0105b
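In substance, the position-id tensor was previously hard-coded as uint32_t; this commit adds a second template parameter U to RoPEKernel and a pos_type argument to the launcher, which the new ROPE_KERNEL macro uses to pick U at launch time. The sketch below restates that dispatch as a plain host-side helper for orientation only; it is not code from this commit, and the numeric codes are simply those that appear in the macro's cases.

#include <cstdint>

// Illustrative only: the same code -> type dispatch that ROPE_KERNEL performs,
// written as a generic host-side helper. The callback receives a value of the
// selected position-id type (e.g. via a generic lambda).
template <typename F>
bool dispatch_pos_type(int32_t pos_type, F &&f) {
    switch (pos_type) {
    case 3:  f(int8_t{});   return true;
    case 4:  f(int16_t{});  return true;
    case 5:  f(int32_t{});  return true;
    case 6:  f(int64_t{});  return true;
    case 7:  f(uint8_t{});  return true;
    case 8:  f(uint16_t{}); return true;
    case 9:  f(uint32_t{}); return true;  // the only width accepted before this change
    case 10: f(uint64_t{}); return true;
    default: return false;                // unknown code: caller reports a dtype error
    }
}

In the device code the same dispatch is written as the ROPE_KERNEL macro so that each case instantiates a concrete RoPEKernel<TYPE, U> directly inside the __global__ entry points.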
@@ -33,6 +33,7 @@ extern "C" infiniStatus_t rope_kernel_launch(
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
@@ -48,17 +49,16 @@ infiniStatus_t Descriptor::calculate(
const void *sin_table,
const void *cos_table,
void *stream) const {
// TODO: can this check be removed?
CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32);
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen;
int32_t nhead = _info.nhead;
int32_t dhead = _info.dhead;
int32_t data_type = _info.data_type;
int32_t pos_type = _info.pos_type;
int32_t y_stride_seqlen = _info.y_stride_seqlen;
int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
@@ -7,7 +7,7 @@ constexpr int32_t BYTE_ALIGN = 32;
using namespace AscendC;
template <typename T>
template <typename T, typename U>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
@@ -43,7 +43,7 @@ private:
TBuf<TPosition::VECCALC> tmpEvenBuf2;
GlobalTensor<T> xGm, yGm;
GlobalTensor<uint32_t> pGm;
GlobalTensor<U> pGm;
GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm;
@@ -60,11 +60,11 @@ private:
int32_t st_xnh_;
};
template <typename T>
__aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh) {
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh) {
this->_tile_len = dh;
this->st_ynt_ = st_ynt;
this->st_ynh_ = st_ynh;
@@ -77,7 +77,7 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
// Init global buffer
xGm.SetGlobalBuffer((__gm__ T *)x);
pGm.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t *>(pos));
pGm.SetGlobalBuffer((__gm__ U *)pos);
sinGm.SetGlobalBuffer((__gm__ T *)sin);
cosGm.SetGlobalBuffer((__gm__ T *)cos);
yGm.SetGlobalBuffer((__gm__ T *)y);
@@ -95,8 +95,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
pipe.InitBuffer(tmpEvenBuf2, _tile_len / 2 * sizeof(T));
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) {
LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
@@ -114,8 +114,8 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
cosQue.EnQue(cosUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::Compute(int32_t i) {
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) {
LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>();
@@ -166,8 +166,8 @@ __aicore__ inline void RoPEKernel<T>::Compute(int32_t i) {
cosQue.FreeTensor(cosUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) {
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyOut(int32_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>();
auto idy = i * st_ynt_ + _block_idx * st_ynh_;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
@@ -175,8 +175,8 @@ __aicore__ inline void RoPEKernel<T>::CopyOut(int32_t i) {
outQue.FreeTensor(outputUb);
}
template <typename T>
__aicore__ inline void RoPEKernel<T>::Process(int32_t nt) {
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Process(int32_t nt) {
for (int32_t i = 0; i < nt; ++i) {
CopyIn(i);
@@ -185,22 +185,73 @@ __aicore__ inline void RoPEKernel<T>::Process(int32_t nt) {
}
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
case 3: { \
RoPEKernel<TYPE, int8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 4: { \
RoPEKernel<TYPE, int16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 5: { \
RoPEKernel<TYPE, int32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 6: { \
RoPEKernel<TYPE, int64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 7: { \
RoPEKernel<TYPE, uint8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 8: { \
RoPEKernel<TYPE, uint16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 9: { \
RoPEKernel<TYPE, uint32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
case 10: { \
RoPEKernel<TYPE, uint64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
break; \
} \
}
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) {
RoPEKernel<half> op;
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
op.Process(seq_len);
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t pos_type) {
ROPE_KERNEL(half, pos_type)
}
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead) {
RoPEKernel<float> op;
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
op.Process(seq_len);
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t pos_type) {
ROPE_KERNEL(float, pos_type)
}
extern "C" infiniStatus_t rope_kernel_launch(void *y,
@@ -212,6 +263,7 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
@@ -219,10 +271,10 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
void *stream) {
switch (data_type) {
case 12: // float16
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break;
case 13: // float32
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
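For completeness, a hypothetical caller of the updated entry point might look like the sketch below. The wrapper name and its parameters are placeholders; the argument order matches the rope_kernel_launch call in Descriptor::calculate above, and the numeric codes (12 for f16 data, 6 for int64 position ids) mirror the switches in this diff. It assumes the header declaring infiniStatus_t and rope_kernel_launch is included.

#include <cstdint>

// Hypothetical host-side wrapper, not part of this commit.
infiniStatus_t rope_f16_with_i64_pos(void *y, void *x, const int64_t *pos_ids,
                                     void *sin_table, void *cos_table,
                                     int32_t seq_len, int32_t nhead, int32_t dhead,
                                     int32_t y_stride_seqlen, int32_t y_stride_nhead,
                                     int32_t x_stride_seqlen, int32_t x_stride_nhead,
                                     void *stream) {
    return rope_kernel_launch(y, x, (void *)pos_ids, sin_table, cos_table,
                              seq_len, nhead, dhead,
                              /*data_type=*/12, // f16 activations
                              /*pos_type=*/6,   // int64_t position ids, now accepted
                              y_stride_seqlen, y_stride_nhead,
                              x_stride_seqlen, x_stride_nhead, stream);
}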