Commit 8727bdcf authored by zhangyue
Browse files

rename private vars

parent 30bf79f1
......@@ -32,21 +32,21 @@ private:
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> inQue;
TQue<QuePosition::VECIN, BUFFER_NUM> sinQue;
TQue<QuePosition::VECIN, BUFFER_NUM> cosQue;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQue;
TBuf<TPosition::VECCALC> tmpOddBuf;
TBuf<TPosition::VECCALC> tmpEvenBuf;
TBuf<TPosition::VECCALC> tmpOddBuf1;
TBuf<TPosition::VECCALC> tmpOddBuf2;
TBuf<TPosition::VECCALC> tmpEvenBuf1;
TBuf<TPosition::VECCALC> tmpEvenBuf2;
TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
TBuf<TPosition::VECCALC> _tmp_odd_buf;
TBuf<TPosition::VECCALC> _tmp_even_buf;
TBuf<TPosition::VECCALC> _tmp_odd_buf1;
TBuf<TPosition::VECCALC> _tmp_odd_buf2;
TBuf<TPosition::VECCALC> _tmp_even_buf1;
TBuf<TPosition::VECCALC> _tmp_even_buf2;
GlobalTensor<T> xGm, yGm;
GlobalTensor<U> pGm;
GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm;
GlobalTensor<T> _x_gm, _y_gm;
GlobalTensor<U> _p_gm;
GlobalTensor<T> _sin_gm;
GlobalTensor<T> _cos_gm;
size_t _block_idx;
size_t _tile_len;
......@@ -83,57 +83,57 @@ __aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
_block_idx = GetBlockIdx();
// Init global buffer
xGm.SetGlobalBuffer((__gm__ T *)x);
pGm.SetGlobalBuffer((__gm__ U *)pos);
sinGm.SetGlobalBuffer((__gm__ T *)sin);
cosGm.SetGlobalBuffer((__gm__ T *)cos);
yGm.SetGlobalBuffer((__gm__ T *)y);
_x_gm.SetGlobalBuffer((__gm__ T *)x);
_p_gm.SetGlobalBuffer((__gm__ U *)pos);
_sin_gm.SetGlobalBuffer((__gm__ T *)sin);
_cos_gm.SetGlobalBuffer((__gm__ T *)cos);
_y_gm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(inQue, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(outQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(sinQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(cosQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(tmpOddBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
// copyIn: for index i, allocates tensors from the input, sin and cos queues,
// fills them from the global tensors with DataCopy, and enqueues them.
//
// NOTE(review): this listing appears to be a rendered diff of the
// "rename private vars" commit, so each statement shows up twice: first with
// the old camelCase names (inQue, xGm, inputUb, ...) and then with the new
// snake_case names (_in_que, _x_gm, input_ub, ...). Only the second variant of
// each pair is expected to exist in the file after the commit -- confirm
// against the repository.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
// Take one tensor from each of the three input-side queues.
LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
// Get the offset of the current tile in the whole input:
// i scaled by _st_xnt plus this block's index scaled by _st_xnh.
auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy the current tile into UB
DataCopy(inputUb, xGm[idx], _copy_len);
DataCopy(input_ub, _x_gm[idx], _copy_len);
// Copy the sin/cos tiles: the position value read at index i selects the
// row, and each row is addressed with a stride of _tile_len / 2.
auto pos_idx = pGm(i);
DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _half_copy_len);
auto pos_idx = _p_gm(i);
DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands so that compute() can dequeue them
inQue.EnQue(inputUb);
sinQue.EnQue(sinUb);
cosQue.EnQue(cosUb);
_in_que.EnQue(input_ub);
_sin_que.EnQue(sin_ub);
_cos_que.EnQue(cos_ub);
}
// compute: dequeues the input, sin and cos tensors, applies the rotation
// described by the formulas in the comments below to the two gathered halves
// of the input, interleaves the results into a tensor allocated from the
// output queue, enqueues that tensor and frees the three inputs.
// The parameter i is not used by any statement visible in this listing.
//
// NOTE(review): this listing appears to be a rendered diff of the
// "rename private vars" commit; statements come in old-name (camelCase) /
// new-name (snake_case) pairs, and one hunk boundary inside the body hides
// some lines. Confirm the final code against the repository.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
// Operands produced by copyIn(), plus a fresh tensor for the result.
LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>();
LocalTensor<T> outputUb = outQue.AllocTensor<T>();
LocalTensor<T> input_ub = _in_que.DeQue<T>();
LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
// Scratch tensors backed by the TBuf members (old names).
LocalTensor<T> tmpOdd = tmpOddBuf.Get<T>();
LocalTensor<T> tmpEven = tmpEvenBuf.Get<T>();
LocalTensor<T> tmpOdd1 = tmpOddBuf1.Get<T>();
LocalTensor<T> tmpOdd2 = tmpOddBuf2.Get<T>();
LocalTensor<T> tmpEven1 = tmpEvenBuf1.Get<T>();
LocalTensor<T> tmpEven2 = tmpEvenBuf2.Get<T>();
// Scratch tensors backed by the TBuf members (new names).
LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
// separate odd and even bit elements
uint64_t rsvdCnt = 0;
// Diff hunk boundary: the start of the gMaskParams initializer is elided
// here; only its last two fields (8, 8) and the closing brace are visible.
......@@ -143,43 +143,43 @@ __aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
8,
8,
};
// Gather with pattern selector 1 into the "odd" scratch tensor and with
// pattern selector 2 into the "even" scratch tensor.
GatherMask<T>(tmpOdd, inputUb, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmpEven, inputUb, 2, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
// Barrier before the gathered halves are consumed
// (presumably orders the vector instructions -- confirm).
PipeBarrier<PIPE_V>();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmpOdd1, tmpOdd, cosUb, _tile_len / 2);
Mul<T>(tmpOdd2, tmpEven, sinUb, _tile_len / 2);
Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
// Both products must be available before the subtraction below.
PipeBarrier<PIPE_V>();
Sub<T>(tmpOdd1, tmpOdd1, tmpOdd2, _tile_len / 2);
Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmpEven1, tmpOdd, sinUb, _tile_len / 2);
Mul<T>(tmpEven2, tmpEven, cosUb, _tile_len / 2);
Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
// Both products must be available before the addition below.
PipeBarrier<PIPE_V>();
Add<T>(tmpEven1, tmpEven1, tmpEven2, _tile_len / 2);
Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
// combine odd and even bit elements:
// element-by-element interleave, "odd" results at output index 2j and
// "even" results at output index 2j + 1.
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
outputUb(j * 2) = tmpOdd1(j);
outputUb(j * 2 + 1) = tmpEven1(j);
output_ub(j * 2) = tmp_odd1(j);
output_ub(j * 2 + 1) = tmp_even1(j);
}
// Hand the result to copyOut() and release the consumed inputs.
outQue.EnQue<T>(outputUb);
inQue.FreeTensor(inputUb);
sinQue.FreeTensor(sinUb);
cosQue.FreeTensor(cosUb);
_out_que.EnQue<T>(output_ub);
_in_que.FreeTensor(input_ub);
_sin_que.FreeTensor(sin_ub);
_cos_que.FreeTensor(cos_ub);
}
// copyOut: dequeues the tensor produced by compute(), writes _tile_len
// elements of type T to the output global tensor at the offset computed for
// index i and this block, then returns the tensor to the output queue.
//
// NOTE(review): this listing appears to be a rendered diff of the
// "rename private vars" commit; the camelCase statements are the old variants
// and the snake_case statements their replacements -- confirm against the
// repository.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>();
LocalTensor<T> output_ub = _out_que.DeQue<T>();
// Destination offset: i scaled by _st_ynt plus this block's index scaled by
// _st_ynh.
auto idy = i * _st_ynt + _block_idx * _st_ynh;
// One block of _tile_len * sizeof(T) bytes; the remaining fields are zero.
// presumably {blockCount, blockLen, srcStride, dstStride, reserved} --
// TODO confirm the field order against the Ascend C documentation.
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm[idy], outputUb, params);
outQue.FreeTensor(outputUb);
DataCopyPad(_y_gm[idy], output_ub, params);
_out_que.FreeTensor(output_ub);
}
template <typename T, typename U>
......@@ -192,56 +192,30 @@ __aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
}
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
case INFINI_DTYPE_I8: { \
RoPEKernel<TYPE, int8_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I16: { \
RoPEKernel<TYPE, int16_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I32: { \
RoPEKernel<TYPE, int32_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_I64: { \
RoPEKernel<TYPE, int64_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U8: { \
RoPEKernel<TYPE, uint8_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U16: { \
RoPEKernel<TYPE, uint16_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U32: { \
RoPEKernel<TYPE, uint32_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case INFINI_DTYPE_U64: { \
RoPEKernel<TYPE, uint64_t> op; \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
// CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T)
// Expands to one `case POS_TYPE_ENUM:` branch that instantiates
// RoPEKernel<TYPE, POS_T>, calls init() with the ROPE_KERNEL_INIT_ARGS
// argument list defined above, calls process(seq_len), and breaks.
// Every name in ROPE_KERNEL_INIT_ARGS, as well as seq_len, must be in scope
// where the macro is expanded.
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
// ROPE_KERNEL(TYPE, POSTYPE)
// Switches on the runtime value POSTYPE and runs the kernel for data type
// TYPE with the matching position-index type (signed and unsigned 8/16/32/64
// bit integers). Any other POSTYPE value reaches the default branch, which
// does nothing.
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
......@@ -264,6 +238,9 @@ DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment