Commit bbb0105b authored by zhangyunze's avatar zhangyunze Committed by zhangyue
Browse files

添加对非32字节对齐的输入形状支持

parent 00842e65
...@@ -48,8 +48,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -48,8 +48,7 @@ infiniStatus_t Descriptor::calculate(
const void *sin_table, const void *sin_table,
const void *cos_table, const void *cos_table,
void *stream) const { void *stream) const {
// TODO: 是否强加这个判断 // TODO: 是否有可能解除这个判断
std::cout << "pos_type: " << _info.pos_type << std::endl;
CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32); CHECK_DTYPE(_info.pos_type, INFINI_DTYPE_U32);
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16); CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen; int32_t seq_len = _info.seqlen;
...@@ -60,9 +59,6 @@ infiniStatus_t Descriptor::calculate( ...@@ -60,9 +59,6 @@ infiniStatus_t Descriptor::calculate(
int32_t y_stride_nhead = _info.y_stride_nhead; int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen; int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead; int32_t x_stride_nhead = _info.x_stride_nhead;
std::cout << "shape is " << seq_len << ", " << nhead << ", " << dhead << std::endl;
std::cout << "y_stride is " << y_stride_seqlen << ", " << y_stride_nhead << ", 1" << std::endl;
std::cout << "x_stride is " << x_stride_seqlen << ", " << x_stride_nhead << ", 1" << std::endl;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream); return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
} }
} // namespace op::rope::ascend } // namespace op::rope::ascend
...@@ -49,6 +49,8 @@ private: ...@@ -49,6 +49,8 @@ private:
uint32_t _block_idx; uint32_t _block_idx;
uint32_t _tile_len; uint32_t _tile_len;
uint32_t _copy_len;
uint32_t _half_copy_len;
// stridey[st_ynt_, st_ynh_, 1] // stridey[st_ynt_, st_ynh_, 1]
int32_t st_ynt_; int32_t st_ynt_;
...@@ -68,6 +70,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM ...@@ -68,6 +70,8 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
this->st_ynh_ = st_ynh; this->st_ynh_ = st_ynh;
this->st_xnt_ = st_xnt; this->st_xnt_ = st_xnt;
this->st_xnh_ = st_xnh; this->st_xnh_ = st_xnh;
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_half_copy_len = (_tile_len / 2 * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len / 2 : (_tile_len / 2 * sizeof(T) + (BYTE_ALIGN - _tile_len / 2 * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_block_idx = GetBlockIdx(); _block_idx = GetBlockIdx();
...@@ -79,10 +83,10 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM ...@@ -79,10 +83,10 @@ __aicore__ inline void RoPEKernel<T>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM
yGm.SetGlobalBuffer((__gm__ T *)y); yGm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer // Init Queue buffer
pipe.InitBuffer(inQue, BUFFER_NUM, _tile_len * sizeof(T)); pipe.InitBuffer(inQue, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(outQue, BUFFER_NUM, _tile_len * sizeof(T)); pipe.InitBuffer(outQue, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(sinQue, BUFFER_NUM, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(sinQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(cosQue, BUFFER_NUM, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(cosQue, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(tmpOddBuf, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(tmpOddBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpEvenBuf, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(tmpEvenBuf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(tmpOddBuf1, _tile_len / 2 * sizeof(T)); pipe.InitBuffer(tmpOddBuf1, _tile_len / 2 * sizeof(T));
...@@ -99,11 +103,11 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) { ...@@ -99,11 +103,11 @@ __aicore__ inline void RoPEKernel<T>::CopyIn(int32_t i) {
// Get idx of current tile in total input // Get idx of current tile in total input
auto idx = i * st_xnt_ + _block_idx * st_xnh_; auto idx = i * st_xnt_ + _block_idx * st_xnh_;
// Copy tile current tile into UB // Copy tile current tile into UB
DataCopy(inputUb, xGm[idx], _tile_len); DataCopy(inputUb, xGm[idx], _copy_len);
// Copy sin cos tile // Copy sin cos tile
auto pos_idx = pGm(i); auto pos_idx = pGm(i);
DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _tile_len / 2); DataCopy(sinUb, sinGm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _tile_len / 2); DataCopy(cosUb, cosGm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands // Push in operands
inQue.EnQue(inputUb); inQue.EnQue(inputUb);
sinQue.EnQue(sinUb); sinQue.EnQue(sinUb);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment