Unverified Commit 1be004cb authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #203 from InfiniTensor/ascend-rope

feat: 添加昇腾rope算子
parents 5beab8c0 8727bdcf
...@@ -23,10 +23,9 @@ include_directories( ...@@ -23,10 +23,9 @@ include_directories(
${CMAKE_SOURCE_DIR}/../../../../include/infiniop/ ${CMAKE_SOURCE_DIR}/../../../../include/infiniop/
) )
ascendc_library(ascend_kernels STATIC ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp ../../ops/rope/ascend/rope_ascend_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp # ../../ops/random_sample/ascend/random_sample_kernel.cpp
) )
...@@ -4,8 +4,17 @@ ...@@ -4,8 +4,17 @@
#include "../../../../include/infinicore.h" #include "../../../../include/infinicore.h"
#include "kernel_operator.h" #include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8; constexpr size_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2; constexpr size_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32; constexpr size_t BYTE_ALIGN = 32;
// Round a tile length (counted in elements of T) up so that its size in
// bytes is a multiple of `byte_align`, returning the padded element count.
// Used to size UB staging buffers for DataCopy, which requires aligned
// transfer lengths.
template <typename T>
__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
    const size_t bytes = tile_len * sizeof(T);
    const size_t rem = bytes % byte_align;
    const size_t padded = (rem == 0) ? bytes : bytes + (byte_align - rem);
    return padded / sizeof(T);
}
#endif #endif
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
// Defaulted: this descriptor holds no Ascend resources of its own
// (workspace pointer is null, see create()).
Descriptor::~Descriptor()
= default;
// Create a RoPE descriptor for the Ascend backend.
//
// Validates the y/x/pos/sin/cos tensor descriptors via RoPEInfo and stores
// the extracted shape/stride metadata in the new Descriptor. This backend
// needs no device workspace, so workspace_size is 0.
//
// Returns INFINI_STATUS_SUCCESS on success; otherwise the error status
// propagated from RoPEInfo::createRoPEInfo via CHECK_RESULT.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    // Fix typo: "ascned" -> "ascend".
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
    auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(result);
    // The Ascend kernel works entirely out of UB staging buffers; no
    // auxiliary device workspace is required.
    size_t workspace_size = 0;
    *desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr,
                               handle_ascend->device, handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Run RoPE on the Ascend device described by this descriptor.
//
// workspace / workspace_size are unused (this backend requires none).
// y is the output tensor; x, pos_ids, sin_table, cos_table are read-only
// inputs forwarded to the AscendC kernel launcher. Only F32/F16 data
// types are accepted (checked below); other dtypes fail fast.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
    auto data_type = _info.data_type;
    auto pos_type = _info.pos_type;
    auto seq_len = _info.seqlen;
    auto nhead = _info.nhead;
    auto dhead = _info.dhead;
    auto y_stride_seqlen = _info.y_stride_seqlen;
    auto y_stride_nhead = _info.y_stride_nhead;
    auto x_stride_seqlen = _info.x_stride_seqlen;
    auto x_stride_nhead = _info.x_stride_nhead;
    // rope_kernel_launch takes non-const pointers (extern "C" ABI); the
    // kernel only reads from these buffers, so casting away const here is
    // safe. Use explicit const_cast instead of C-style casts.
    return rope_kernel_launch(y,
                              const_cast<void *>(x),
                              const_cast<void *>(pos_ids),
                              const_cast<void *>(sin_table),
                              const_cast<void *>(cos_table),
                              seq_len, nhead, dhead, data_type, pos_type,
                              y_stride_seqlen, y_stride_nhead,
                              x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
// Host-side entry point implemented in the AscendC kernel translation unit.
// Selects the kernel instantiation matching data_type/pos_type and launches
// it on `stream`. Pointers are untyped because the extern "C" ABI cannot
// carry the element type; x/pos/sin/cos are only read by the kernel.
extern "C" infiniStatus_t rope_kernel_launch(
void *y,                      // output tensor [seq_len, nhead, dhead]
void *x,                      // input tensor  [seq_len, nhead, dhead]
void *pos,                    // position ids, one per sequence element
void *sin,                    // sin table, indexed by position id
void *cos,                    // cos table, indexed by position id
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,      // element type of y/x/sin/cos (F16 or F32)
infiniDtype_t pos_type,       // integer type of pos (any I8..U64)
ptrdiff_t y_stride_seqlen,    // y strides, in elements
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,    // x strides, in elements
ptrdiff_t x_stride_nhead,
void *stream);
// Expands to the op::rope::ascend::Descriptor declaration (see ../rope.h).
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;
// Rotary positional embedding kernel.
// T: element type (half/float); U: integer type of the position ids.
// Launch layout: one block per attention head (block_num = nh); each block
// iterates over the whole sequence, processing one row of dh elements
// (one tile) per iteration through a copyIn -> compute -> copyOut pipeline.
template <typename T, typename U>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
// Init op
// pos position vector
// x input tensor
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
// Run the pipeline over all seq_len rows owned by this block.
__aicore__ inline void process(size_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void copyIn(size_t i);
// Rotate the staged tile: odd/even split, multiply-add with sin/cos.
__aicore__ inline void compute(size_t i);
// Write the rotated tile back to global memory.
__aicore__ inline void copyOut(size_t i);
private:
TPipe pipe;
// Double-buffered queues: input row, sin/cos rows, output row.
TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
// Scratch buffers (dh/2 elements each) for the odd/even element lanes
// and their intermediate products.
TBuf<TPosition::VECCALC> _tmp_odd_buf;
TBuf<TPosition::VECCALC> _tmp_even_buf;
TBuf<TPosition::VECCALC> _tmp_odd_buf1;
TBuf<TPosition::VECCALC> _tmp_odd_buf2;
TBuf<TPosition::VECCALC> _tmp_even_buf1;
TBuf<TPosition::VECCALC> _tmp_even_buf2;
GlobalTensor<T> _x_gm, _y_gm;
GlobalTensor<U> _p_gm;
GlobalTensor<T> _sin_gm;
GlobalTensor<T> _cos_gm;
size_t _block_idx;     // head index this block handles
size_t _tile_len;      // dh, elements per row
size_t _copy_len;      // dh padded to the 32-byte DataCopy alignment
size_t _half_copy_len; // sin/cos row length padded to alignment
// stridey[_st_ynt, _st_ynh, 1]
ptrdiff_t _st_ynt;
ptrdiff_t _st_ynh;
// stridex[_st_xnt, _st_xnh, 1]
ptrdiff_t _st_xnt;
ptrdiff_t _st_xnh;
};
// Bind global buffers, record shape/stride parameters, and allocate all
// UB queues and scratch buffers for one (head, sequence) pipeline.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
                                              GM_ADDR x,
                                              GM_ADDR pos,
                                              GM_ADDR sin,
                                              GM_ADDR cos,
                                              size_t dh,
                                              ptrdiff_t st_ynt,
                                              ptrdiff_t st_ynh,
                                              ptrdiff_t st_xnt,
                                              ptrdiff_t st_xnh) {
    this->_tile_len = dh;
    this->_st_ynt = st_ynt;
    this->_st_ynh = st_ynh;
    this->_st_xnt = st_xnt;
    this->_st_xnh = st_xnh;
    _copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
    // BUG FIX: the sin/cos tables are read in half-rows of dh/2 elements
    // (see copyIn's pos_idx * _tile_len / 2 indexing), but _half_copy_len
    // was computed from the full dh, making it identical to _copy_len and
    // copying twice the intended amount. Align dh/2 instead.
    _half_copy_len = alignTileLen<T>(dh / 2, BYTE_ALIGN);
    _block_idx = GetBlockIdx();
    // Init global buffer
    _x_gm.SetGlobalBuffer((__gm__ T *)x);
    _p_gm.SetGlobalBuffer((__gm__ U *)pos);
    _sin_gm.SetGlobalBuffer((__gm__ T *)sin);
    _cos_gm.SetGlobalBuffer((__gm__ T *)cos);
    _y_gm.SetGlobalBuffer((__gm__ T *)y);
    // Init Queue buffer: input is staged at the aligned copy length,
    // output at the exact tile length (written back with DataCopyPad).
    pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
    pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
    pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    // Six dh/2 scratch buffers for odd/even lanes and their products.
    pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
// Stage the i-th sequence element: its input row for this head, plus the
// sin/cos table rows selected by the element's position id.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
    LocalTensor<T> x_tile = _in_que.AllocTensor<T>();
    LocalTensor<T> sin_tile = _sin_que.AllocTensor<T>();
    LocalTensor<T> cos_tile = _cos_que.AllocTensor<T>();
    // Offset of row (i, _block_idx) in the strided input tensor.
    size_t src_off = i * _st_xnt + _block_idx * _st_xnh;
    DataCopy(x_tile, _x_gm[src_off], _copy_len);
    // The position id selects the table row; rows are dh/2 elements apart
    // (matching the odd/even pairing done in compute()).
    auto pos_idx = _p_gm(i);
    DataCopy(sin_tile, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
    DataCopy(cos_tile, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
    // Hand the staged operands to the compute stage.
    _in_que.EnQue(x_tile);
    _sin_que.EnQue(sin_tile);
    _cos_que.EnQue(cos_tile);
}
// Apply the rotary embedding to one staged row of dh elements.
// The row is split into its odd- and even-indexed lanes, rotated against
// the sin/cos rows, then re-interleaved into the output tile.
// NOTE: statement order is pipeline-sensitive; PipeBarrier<PIPE_V>() calls
// separate vector ops whose outputs feed the next stage.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> input_ub = _in_que.DeQue<T>();
LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
// separate odd and even bit elements
uint64_t rsvdCnt = 0;
// repeat count covers the whole tile in 256-byte vector repeats.
GatherMaskParams gMaskParams = {
1,
static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // no more than 256(<=255)
8,
8,
};
// Pattern 1/2 select the odd-/even-indexed lanes respectively.
GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
// combine odd and even bit elements
// NOTE(review): scalar element-by-element interleave; a vector
// scatter/zip would likely be faster here — confirm before optimizing.
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
output_ub(j * 2) = tmp_odd1(j);
output_ub(j * 2 + 1) = tmp_even1(j);
}
_out_que.EnQue<T>(output_ub);
_in_que.FreeTensor(input_ub);
_sin_que.FreeTensor(sin_ub);
_cos_que.FreeTensor(cos_ub);
}
// Write the rotated row back to global memory. DataCopyPad is used because
// _tile_len * sizeof(T) need not be a multiple of the 32-byte alignment.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
    LocalTensor<T> result_tile = _out_que.DeQue<T>();
    // Offset of row (i, _block_idx) in the strided output tensor.
    size_t dst_off = i * _st_ynt + _block_idx * _st_ynh;
    DataCopyExtParams copy_params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
    DataCopyPad(_y_gm[dst_off], result_tile, copy_params);
    _out_que.FreeTensor(result_tile);
}
// Drive the three-stage pipeline (stage in, rotate, stage out) once per
// sequence element; the queues overlap the stages across iterations.
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
    for (size_t tok = 0; tok != seq_len; ++tok) {
        copyIn(tok);
        compute(tok);
        copyOut(tok);
    }
}
// Argument pack forwarded from the __global__ kernel entry to init().
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
// One switch case: instantiate RoPEKernel for the given element/position
// type pair and run it.
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
// Dispatch on the runtime position-id dtype; unknown types are a no-op.
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
// Device entry point for one element type; pos_type is dispatched inside.
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
GM_ADDR x, \
GM_ADDR pos, \
GM_ADDR sin, \
GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
// Keep the helper macros local to this translation unit.
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
// Host-side launcher: picks the kernel instantiation for `dtype` and
// enqueues it on `stream` with one block per attention head. The kernel
// dispatches on `pos_type` internally, so only the element type is
// switched here.
//
// Returns INFINI_STATUS_SUCCESS after enqueueing, or
// INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported element types.
extern "C" infiniStatus_t rope_kernel_launch(
    void *y,
    void *x,
    void *pos,
    void *sin,
    void *cos,
    size_t seq_len,
    size_t nhead,
    size_t dhead,
    infiniDtype_t dtype,
    infiniDtype_t pos_type,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead,
    void *stream) {
// One kernel block per head; each block walks the whole sequence.
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME)             \
    case DTYPE_ENUM:                                            \
        KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
                                                seq_len,        \
                                                dhead,          \
                                                y_stride_seqlen, \
                                                y_stride_nhead, \
                                                x_stride_seqlen, \
                                                x_stride_nhead, \
                                                pos_type);      \
        return INFINI_STATUS_SUCCESS;
    switch (dtype) {
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
// FIX: the macro was never undefined and leaked past this function.
#undef LAUNCH_ROPE_KERNEL
}
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh" #include "cuda/rope_cuda.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_ascend.h"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
pos_ids, sin_table, cos_table); pos_ids, sin_table, cos_table);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { CREATE(INFINI_DEVICE_ASCEND, ascend);
return ascendCreateRoPEDescriptor((AscendHandle_t)handle,
(RoPEAscendDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size); return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { GET(INFINI_DEVICE_ASCEND, ascend);
return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t)desc, size);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE( ...@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(
t, pos_ids, sin_table, cos_table, stream); t, pos_ids, sin_table, cos_table, stream);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { CALCULATE(INFINI_DEVICE_ASCEND, ascend);
return ascendRoPE((RoPEAscendDescriptor_t)desc, workspace,
workspace_size, t, pos_ids, sin_table, cos_table,
stream);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
...@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc); return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc);
} }
#endif #endif
#ifdef ENABLE_ASCEND_NPU #ifdef ENABLE_ASCEND_API
case DevAscendNpu: { DELETE(INFINI_DEVICE_ASCEND, ascend);
return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t)desc);
}
#endif #endif
#ifdef ENABLE_METAX_GPU #ifdef ENABLE_METAX_GPU
case DevMetaxGpu: { case DevMetaxGpu: {
......
...@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr ...@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
infiniStatus_t Descriptor::calculate(void *workspace, infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size, size_t workspace_size,
void *c, void *c,
......
...@@ -69,5 +69,11 @@ public: ...@@ -69,5 +69,11 @@ public:
void *stream) const; void *stream) const;
}; };
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
} // namespace op::swiglu::ascend } // namespace op::swiglu::ascend
#endif // __ACLNN_SWIGLU_H__ #endif // __ACLNN_SWIGLU_H__
...@@ -6,15 +6,20 @@ template <typename T> ...@@ -6,15 +6,20 @@ template <typename T>
class SwigluKernel { class SwigluKernel {
public: public:
__aicore__ inline SwigluKernel() {} __aicore__ inline SwigluKernel() {}
__aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, size_t batch_, size_t seq, size_t hd,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b); ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b);
__aicore__ inline void process(); __aicore__ inline void process();
private: private:
__aicore__ inline void copyIn(int64_t i); __aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(int64_t i); __aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(int64_t i); __aicore__ inline void copyOut(size_t i);
private: private:
GlobalTensor<T> _c_gm, _a_gm, _b_gm; GlobalTensor<T> _c_gm, _a_gm, _b_gm;
...@@ -23,16 +28,21 @@ private: ...@@ -23,16 +28,21 @@ private:
TPipe _pipe; TPipe _pipe;
float _beta_value = 1.0f; float _beta_value = 1.0f;
int64_t _block_idx, _tile_len, _copy_len, size_t _block_idx, _tile_len, _copy_len,
_batch, _seq_len, _hidden_size, _batch, _seq_len, _hidden_size,
_stride_seq_a, _stride_seq_b, _stride_seq_c; _stride_seq_a, _stride_seq_b, _stride_seq_c;
int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1; int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
}; };
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, size_t batch_, size_t seq, size_t hd,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) { ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b) {
// Init Shape & StrideVariables // Init Shape & StrideVariables
_batch = batch_; _batch = batch_;
_seq_len = seq; _seq_len = seq;
...@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in ...@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
_block_idx = GetBlockIdx(); _block_idx = GetBlockIdx();
_tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM); _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T); _copy_len = alignTileLen<T>(_tile_len, BYTE_ALIGN);
// Set global tensor // Set global tensor
_a_gm.SetGlobalBuffer((__gm__ T *)a); _a_gm.SetGlobalBuffer((__gm__ T *)a);
...@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in ...@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyIn(size_t i) {
// Alloc tensor from queue memory // Alloc tensor from queue memory
LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>(); LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>(); LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
...@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { ...@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
auto batch_idx = _batch == 1 ? 0 : i / _seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len; ptrdiff_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len; ptrdiff_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
// Copy process_th tile from global tensor to local tensor // Copy process_th tile from global tensor to local tensor
DataCopy(aLocal, _a_gm[idxa], _copy_len); DataCopy(aLocal, _a_gm[idxa], _copy_len);
DataCopy(bLocal, _b_gm[idxb], _copy_len); DataCopy(bLocal, _b_gm[idxb], _copy_len);
...@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { ...@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::compute(int64_t i) { __aicore__ inline void SwigluKernel<T>::compute(size_t i) {
// Deque input tensors from VECIN queue // Deque input tensors from VECIN queue
LocalTensor<T> aLocal = _in_queue_a.DeQue<T>(); LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
LocalTensor<T> bLocal = _in_queue_b.DeQue<T>(); LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
...@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) { ...@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyOut(size_t i) {
// Deque output tensor from VECOUT queue // Deque output tensor from VECOUT queue
LocalTensor<T> cLocal = _out_queue_c.DeQue<T>(); LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
auto batch_idx = _batch == 1 ? 0 : i / _seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len; ptrdiff_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
// Copy progress_th tile from local tensor to global tensor // Copy progress_th tile from local tensor to global tensor
if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) { if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
...@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) { ...@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::process() { __aicore__ inline void SwigluKernel<T>::process() {
for (int64_t i = 0; i < _batch * _seq_len; ++i) { for (size_t i = 0; i < _batch * _seq_len; ++i) {
copyIn(i); copyIn(i);
compute(i); compute(i);
copyOut(i); copyOut(i);
} }
} }
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \ #define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \ __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
int64_t batch, int64_t seq, int64_t hd, \ size_t batch, size_t seq, size_t hd, \
int64_t stride_batch_c, \ ptrdiff_t stride_batch_c, \
int64_t stride_batch_a, \ ptrdiff_t stride_batch_a, \
int64_t stride_batch_b, \ ptrdiff_t stride_batch_b, \
int64_t stride_seq_c, \ ptrdiff_t stride_seq_c, \
int64_t stride_seq_a, \ ptrdiff_t stride_seq_a, \
int64_t stride_seq_b) { \ ptrdiff_t stride_seq_b) { \
SwigluKernel<TYPE> op; \ SwigluKernel<TYPE> op; \
op.init(c, a, b, \ op.init(c, a, b, \
batch, seq, hd, \ batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \ stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \ stride_seq_c, stride_seq_a, stride_seq_b); \
op.process(); \ op.process(); \
} }
DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half) DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
...@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch( ...@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
case DTYPE_ENUM: \ case DTYPE_ENUM: \
KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \ KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \
c, a, b, \ c, a, b, \
static_cast<int64_t>(batch), \ batch, \
static_cast<int64_t>(seq), \ seq, \
static_cast<int64_t>(hd), \ hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \ stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \ stride_seq_c, stride_seq_a, stride_seq_b); \
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -189,6 +189,9 @@ def test( ...@@ -189,6 +189,9 @@ def test(
) )
lib_rope() lib_rope()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment