Commit bbb0105b authored by zhangyunze, committed by zhangyue
Browse files

添加对非32字节对齐的输入形状支持

parent 00842e65
......@@ -48,8 +48,7 @@ infiniStatus_t Descriptor::calculate(
const void *sin_table,
const void *cos_table,
void *stream) const {
// TODO: 是否强加这个判断
std::cout << "pos_type: " << _info.pos_type << std::endl;
// TODO: 是否有可能解除这个判断
CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32);
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen;
......@@ -60,9 +59,6 @@ infiniStatus_t Descriptor::calculate(
int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead;
std::cout << "shape is " << seq_len << ", " << nhead << ", " << dhead << std::endl;
std::cout << "y_stride is " << y_stride_seqlen << ", " << y_stride_nhead << ", 1" << std::endl;
std::cout << "x_stride is " << x_stride_seqlen << ", " << x_stride_nhead << ", 1" << std::endl;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
......@@ -49,6 +49,8 @@ private:
uint32_t _block_idx;
uint32_t _tile_len;
uint32_t _copy_len;
uint32_t _half_copy_len;
// stridey[st_ynt_, st_ynh_, 1]
int32_t st_ynt_;
......@@ -68,6 +70,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
this->st_ynh_ = st_ynh;
this->st_xnt_ = st_xnt;
this->st_xnh_ = st_xnh;
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_half_copy_len = (_tile_len / 2 * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len / 2 : (_tile_len / 2 * sizeof(T) + (BYTE_ALIGN - _tile_len / 2 * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_block_idx = GetBlockIdx();
......@@ -79,10 +83,10 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
yGm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(inQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(inQue, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(outQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(sinQue, BUFFER_NUM, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(cosQue, BUFFER_NUM, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(sinQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(cosQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(tmpOddBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf1, _tile_len / 2 * sizeof(T));
......@@ -99,11 +103,11 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
// Get idx of current tile in total input
auto idx = i * st_xnt_ + _block_idx * st_xnh_;
// Copy tile current tile into UB
DataCopy(inputUb, xGm[idx], _tile_len);
DataCopy(inputUb, xGm[idx], _copy_len);
// Copy sin cos tile
auto pos_idx = pGm(i);
DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _tile_len / 2);
DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _tile_len / 2);
DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands
inQue.EnQue(inputUb);
sinQue.EnQue(sinUb);
......@@ -225,4 +229,4 @@ extern "C" infiniStatus_t rope_kernel_launch(void *y,
}
return INFINI_STATUS_SUCCESS;
}
\ No newline at end of file
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment