"tilelang/original/docs/programming_guides/overview.md" did not exist on "f5bc26c2295e334b3a8ce4cce8a7ba4b7927c736"
Commit 30bf79f1 authored by zhangyue's avatar zhangyue
Browse files

restruct input dtype

parent 2c0e9a6e
...@@ -25,7 +25,7 @@ include_directories( ...@@ -25,7 +25,7 @@ include_directories(
ascendc_library(ascend_kernels STATIC ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
../../ops/rope/ascend/rope_kernel.cpp ../../ops/rope/ascend/rope_ascend_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp # ../../ops/random_sample/ascend/random_sample_kernel.cpp
) )
...@@ -4,8 +4,17 @@ ...@@ -4,8 +4,17 @@
#include "../../../../include/infinicore.h" #include "../../../../include/infinicore.h"
#include "kernel_operator.h" #include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8; constexpr size_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2; constexpr size_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32; constexpr size_t BYTE_ALIGN = 32;
template <typename T>
__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
size_t bytes = tile_len * sizeof(T);
size_t aligned_bytes = (bytes % byte_align == 0)
? bytes
: (bytes + (byte_align - bytes % byte_align));
return aligned_bytes / sizeof(T);
}
#endif #endif
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "rope_aclnn.h" #include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h" #include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend { namespace op::rope::ascend {
...@@ -23,23 +23,6 @@ infiniStatus_t Descriptor::create( ...@@ -23,23 +23,6 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
int32_t seq_len,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
int32_t x_stride_nhead,
void *stream);
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
...@@ -50,15 +33,18 @@ infiniStatus_t Descriptor::calculate( ...@@ -50,15 +33,18 @@ infiniStatus_t Descriptor::calculate(
const void *cos_table, const void *cos_table,
void *stream) const { void *stream) const {
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16); CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen;
int32_t nhead = _info.nhead; auto data_type = _info.data_type;
int32_t dhead = _info.dhead; auto pos_type = _info.pos_type;
int32_t data_type = _info.data_type; auto seq_len = _info.seqlen;
int32_t pos_type = _info.pos_type; auto nhead = _info.nhead;
int32_t y_stride_seqlen = _info.y_stride_seqlen; auto dhead = _info.dhead;
int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen; auto y_stride_seqlen = _info.y_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead; auto y_stride_nhead = _info.y_stride_nhead;
auto x_stride_seqlen = _info.x_stride_seqlen;
auto x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream); return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
} }
} // namespace op::rope::ascend } // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream);
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../../../include/infinicore.h" #include "../../../devices/ascend/ascend_kernel_common.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
using namespace AscendC; using namespace AscendC;
...@@ -17,17 +12,23 @@ public: ...@@ -17,17 +12,23 @@ public:
// y output tensor // y output tensor
// tensor shape [nt, nh, dh] // tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh // make block_num = nh, tile_len = dh
__aicore__ inline void Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, __aicore__ inline void init(GM_ADDR y,
int32_t dh, GM_ADDR x,
int32_t st_ynt, int32_t st_ynh, GM_ADDR pos,
int32_t st_xnt, int32_t st_xnh); GM_ADDR sin,
__aicore__ inline void Process(int32_t seq_len); GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
__aicore__ inline void process(size_t seq_len);
private: private:
// Copy a tile into UB // Copy a tile into UB
__aicore__ inline void CopyIn(int32_t i); __aicore__ inline void copyIn(size_t i);
__aicore__ inline void Compute(int32_t i); __aicore__ inline void compute(size_t i);
__aicore__ inline void CopyOut(int32_t i); __aicore__ inline void copyOut(size_t i);
private: private:
TPipe pipe; TPipe pipe;
...@@ -47,31 +48,37 @@ private: ...@@ -47,31 +48,37 @@ private:
GlobalTensor<T> sinGm; GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm; GlobalTensor<T> cosGm;
uint32_t _block_idx; size_t _block_idx;
uint32_t _tile_len; size_t _tile_len;
uint32_t _copy_len; size_t _copy_len;
uint32_t _half_copy_len; size_t _half_copy_len;
// stridey[st_ynt_, st_ynh_, 1] // stridey[_st_ynt, _st_ynh, 1]
int32_t st_ynt_; ptrdiff_t _st_ynt;
int32_t st_ynh_; ptrdiff_t _st_ynh;
// stridex[st_xnt_, st_xnh_, 1] // stridex[_st_xnt, _st_xnh, 1]
int32_t st_xnt_; ptrdiff_t _st_xnt;
int32_t st_xnh_; ptrdiff_t _st_xnh;
}; };
template <typename T, typename U> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, __aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
int32_t dh, GM_ADDR x,
int32_t st_ynt, int32_t st_ynh, GM_ADDR pos,
int32_t st_xnt, int32_t st_xnh) { GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh) {
this->_tile_len = dh; this->_tile_len = dh;
this->st_ynt_ = st_ynt; this->_st_ynt = st_ynt;
this->st_ynh_ = st_ynh; this->_st_ynh = st_ynh;
this->st_xnt_ = st_xnt; this->_st_xnt = st_xnt;
this->st_xnh_ = st_xnh; this->_st_xnh = st_xnh;
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T); _copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_half_copy_len = (_tile_len / 2 * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len / 2 : (_tile_len / 2 * sizeof(T) + (BYTE_ALIGN - _tile_len / 2 * sizeof(T) % BYTE_ALIGN)) / sizeof(T); _half_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_block_idx = GetBlockIdx(); _block_idx = GetBlockIdx();
...@@ -96,12 +103,12 @@ __aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, ...@@ -96,12 +103,12 @@ __aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos,
} }
template <typename T, typename U> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
LocalTensor<T> inputUb = inQue.AllocTensor<T>(); LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>(); LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>(); LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
// Get idx of current tile in total input // Get idx of current tile in total input
auto idx = i * st_xnt_ + _block_idx * st_xnh_; auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy tile current tile into UB // Copy tile current tile into UB
DataCopy(inputUb, xGm[idx], _copy_len); DataCopy(inputUb, xGm[idx], _copy_len);
// Copy sin cos tile // Copy sin cos tile
...@@ -115,7 +122,7 @@ __aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) { ...@@ -115,7 +122,7 @@ __aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) {
} }
template <typename T, typename U> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> inputUb = inQue.DeQue<T>(); LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>(); LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>(); LocalTensor<T> cosUb = cosQue.DeQue<T>();
...@@ -167,118 +174,130 @@ __aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) { ...@@ -167,118 +174,130 @@ __aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) {
} }
template <typename T, typename U> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyOut(int32_t i) { __aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>(); LocalTensor<T> outputUb = outQue.DeQue<T>();
auto idy = i * st_ynt_ + _block_idx * st_ynh_; auto idy = i * _st_ynt + _block_idx * _st_ynh;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm[idy], outputUb, params); DataCopyPad(yGm[idy], outputUb, params);
outQue.FreeTensor(outputUb); outQue.FreeTensor(outputUb);
} }
template <typename T, typename U> template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Process(int32_t nt) { __aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
for (int32_t i = 0; i < nt; ++i) { for (size_t i = 0; i < seq_len; ++i) {
CopyIn(i); copyIn(i);
Compute(i); compute(i);
CopyOut(i); copyOut(i);
} }
} }
#define ROPE_KERNEL(TYPE, POSTYPE) \ #define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \ switch (POSTYPE) { \
case 3: { \ case INFINI_DTYPE_I8: { \
RoPEKernel<TYPE, int8_t> op; \ RoPEKernel<TYPE, int8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 4: { \ case INFINI_DTYPE_I16: { \
RoPEKernel<TYPE, int16_t> op; \ RoPEKernel<TYPE, int16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 5: { \ case INFINI_DTYPE_I32: { \
RoPEKernel<TYPE, int32_t> op; \ RoPEKernel<TYPE, int32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 6: { \ case INFINI_DTYPE_I64: { \
RoPEKernel<TYPE, int64_t> op; \ RoPEKernel<TYPE, int64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 7: { \ case INFINI_DTYPE_U8: { \
RoPEKernel<TYPE, uint8_t> op; \ RoPEKernel<TYPE, uint8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 8: { \ case INFINI_DTYPE_U16: { \
RoPEKernel<TYPE, uint16_t> op; \ RoPEKernel<TYPE, uint16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 9: { \ case INFINI_DTYPE_U32: { \
RoPEKernel<TYPE, uint32_t> op; \ RoPEKernel<TYPE, uint32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
case 10: { \ case INFINI_DTYPE_U64: { \
RoPEKernel<TYPE, uint64_t> op; \ RoPEKernel<TYPE, uint64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \ op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \ op.process(seq_len); \
break; \ break; \
} \ } \
} }
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, #define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
int32_t seq_len, int32_t dhead, __global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
int32_t y_stride_seqlen, int32_t y_stride_nhead, GM_ADDR x, \
int32_t x_stride_seqlen, int32_t x_stride_nhead, GM_ADDR pos, \
int32_t pos_type){ GM_ADDR sin, \
ROPE_KERNEL(half, pos_type) GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
} DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos, #undef DEFINE_ROPE_KERNEL
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t pos_type) {
ROPE_KERNEL(float, pos_type)
}
extern "C" infiniStatus_t rope_kernel_launch(void *y, extern "C" infiniStatus_t rope_kernel_launch(
void *x, void *y,
void *pos, void *x,
void *sin, void *pos,
void *cos, void *sin,
int32_t seq_len, void *cos,
int32_t nhead, size_t seq_len,
int32_t dhead, size_t nhead,
int32_t data_type, size_t dhead,
int32_t pos_type, infiniDtype_t dtype,
int32_t y_stride_seqlen, infiniDtype_t pos_type,
int32_t y_stride_nhead, ptrdiff_t y_stride_seqlen,
int32_t x_stride_seqlen, ptrdiff_t y_stride_nhead,
int32_t x_stride_nhead, ptrdiff_t x_stride_seqlen,
void *stream) { ptrdiff_t x_stride_nhead,
switch (data_type) { void *stream) {
case 12: // float16
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type); #define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
break; case DTYPE_ENUM: \
case 13: // float32 KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type); seq_len, \
break; dhead, \
y_stride_seqlen, \
y_stride_nhead, \
x_stride_seqlen, \
x_stride_nhead, \
pos_type); \
return INFINI_STATUS_SUCCESS;
switch (dtype) {
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
return INFINI_STATUS_SUCCESS;
} }
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "cuda/rope_cuda.cuh" #include "cuda/rope_cuda.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/rope_aclnn.h" #include "ascend/rope_ascend.h"
#endif #endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
......
...@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr ...@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
infiniStatus_t Descriptor::calculate(void *workspace, infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size, size_t workspace_size,
void *c, void *c,
......
...@@ -69,5 +69,11 @@ public: ...@@ -69,5 +69,11 @@ public:
void *stream) const; void *stream) const;
}; };
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
} // namespace op::swiglu::ascend } // namespace op::swiglu::ascend
#endif // __ACLNN_SWIGLU_H__ #endif // __ACLNN_SWIGLU_H__
...@@ -6,15 +6,20 @@ template <typename T> ...@@ -6,15 +6,20 @@ template <typename T>
class SwigluKernel { class SwigluKernel {
public: public:
__aicore__ inline SwigluKernel() {} __aicore__ inline SwigluKernel() {}
__aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, size_t batch_, size_t seq, size_t hd,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b); ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b);
__aicore__ inline void process(); __aicore__ inline void process();
private: private:
__aicore__ inline void copyIn(int64_t i); __aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(int64_t i); __aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(int64_t i); __aicore__ inline void copyOut(size_t i);
private: private:
GlobalTensor<T> _c_gm, _a_gm, _b_gm; GlobalTensor<T> _c_gm, _a_gm, _b_gm;
...@@ -23,16 +28,21 @@ private: ...@@ -23,16 +28,21 @@ private:
TPipe _pipe; TPipe _pipe;
float _beta_value = 1.0f; float _beta_value = 1.0f;
int64_t _block_idx, _tile_len, _copy_len, size_t _block_idx, _tile_len, _copy_len,
_batch, _seq_len, _hidden_size, _batch, _seq_len, _hidden_size,
_stride_seq_a, _stride_seq_b, _stride_seq_c; _stride_seq_a, _stride_seq_b, _stride_seq_c;
int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1; int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
}; };
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, size_t batch_, size_t seq, size_t hd,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) { ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b) {
// Init Shape & StrideVariables // Init Shape & StrideVariables
_batch = batch_; _batch = batch_;
_seq_len = seq; _seq_len = seq;
...@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in ...@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
_block_idx = GetBlockIdx(); _block_idx = GetBlockIdx();
_tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM); _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T); _copy_len = alignTileLen<T>(_tile_len, BYTE_ALIGN);
// Set global tensor // Set global tensor
_a_gm.SetGlobalBuffer((__gm__ T *)a); _a_gm.SetGlobalBuffer((__gm__ T *)a);
...@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in ...@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyIn(size_t i) {
// Alloc tensor from queue memory // Alloc tensor from queue memory
LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>(); LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>(); LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
...@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { ...@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
auto batch_idx = _batch == 1 ? 0 : i / _seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len; ptrdiff_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len; ptrdiff_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
// Copy process_th tile from global tensor to local tensor // Copy process_th tile from global tensor to local tensor
DataCopy(aLocal, _a_gm[idxa], _copy_len); DataCopy(aLocal, _a_gm[idxa], _copy_len);
DataCopy(bLocal, _b_gm[idxb], _copy_len); DataCopy(bLocal, _b_gm[idxb], _copy_len);
...@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { ...@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::compute(int64_t i) { __aicore__ inline void SwigluKernel<T>::compute(size_t i) {
// Deque input tensors from VECIN queue // Deque input tensors from VECIN queue
LocalTensor<T> aLocal = _in_queue_a.DeQue<T>(); LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
LocalTensor<T> bLocal = _in_queue_b.DeQue<T>(); LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
...@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) { ...@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyOut(size_t i) {
// Deque output tensor from VECOUT queue // Deque output tensor from VECOUT queue
LocalTensor<T> cLocal = _out_queue_c.DeQue<T>(); LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
auto batch_idx = _batch == 1 ? 0 : i / _seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len; ptrdiff_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
// Copy progress_th tile from local tensor to global tensor // Copy progress_th tile from local tensor to global tensor
if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) { if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
...@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) { ...@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::process() { __aicore__ inline void SwigluKernel<T>::process() {
for (int64_t i = 0; i < _batch * _seq_len; ++i) { for (size_t i = 0; i < _batch * _seq_len; ++i) {
copyIn(i); copyIn(i);
compute(i); compute(i);
copyOut(i); copyOut(i);
} }
} }
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \ #define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \ __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
int64_t batch, int64_t seq, int64_t hd, \ size_t batch, size_t seq, size_t hd, \
int64_t stride_batch_c, \ ptrdiff_t stride_batch_c, \
int64_t stride_batch_a, \ ptrdiff_t stride_batch_a, \
int64_t stride_batch_b, \ ptrdiff_t stride_batch_b, \
int64_t stride_seq_c, \ ptrdiff_t stride_seq_c, \
int64_t stride_seq_a, \ ptrdiff_t stride_seq_a, \
int64_t stride_seq_b) { \ ptrdiff_t stride_seq_b) { \
SwigluKernel<TYPE> op; \ SwigluKernel<TYPE> op; \
op.init(c, a, b, \ op.init(c, a, b, \
batch, seq, hd, \ batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \ stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \ stride_seq_c, stride_seq_a, stride_seq_b); \
op.process(); \ op.process(); \
} }
DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half) DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
...@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch( ...@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
case DTYPE_ENUM: \ case DTYPE_ENUM: \
KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \ KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \
c, a, b, \ c, a, b, \
static_cast<int64_t>(batch), \ batch, \
static_cast<int64_t>(seq), \ seq, \
static_cast<int64_t>(hd), \ hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \ stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \ stride_seq_c, stride_seq_a, stride_seq_b); \
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -189,6 +189,9 @@ def test( ...@@ -189,6 +189,9 @@ def test(
) )
lib_rope() lib_rope()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment