Commit 30bf79f1 authored by zhangyue's avatar zhangyue
Browse files

restruct input dtype

parent 2c0e9a6e
......@@ -25,7 +25,7 @@ include_directories(
ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
../../ops/rope/ascend/rope_kernel.cpp
../../ops/rope/ascend/rope_ascend_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
)
......@@ -4,8 +4,17 @@
#include "../../../../include/infinicore.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
constexpr size_t BLOCK_NUM = 8;
constexpr size_t BUFFER_NUM = 2;
constexpr size_t BYTE_ALIGN = 32;
template <typename T>
__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
size_t bytes = tile_len * sizeof(T);
size_t aligned_bytes = (bytes % byte_align == 0)
? bytes
: (bytes + (byte_align - bytes % byte_align));
return aligned_bytes / sizeof(T);
}
#endif
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "rope_aclnn.h"
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
......@@ -23,23 +23,6 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
int32_t seq_len,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
int32_t x_stride_nhead,
void *stream);
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
......@@ -50,15 +33,18 @@ infiniStatus_t Descriptor::calculate(
const void *cos_table,
void *stream) const {
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
int32_t seq_len = _info.seqlen;
int32_t nhead = _info.nhead;
int32_t dhead = _info.dhead;
int32_t data_type = _info.data_type;
int32_t pos_type = _info.pos_type;
int32_t y_stride_seqlen = _info.y_stride_seqlen;
int32_t y_stride_nhead = _info.y_stride_nhead;
int32_t x_stride_seqlen = _info.x_stride_seqlen;
int32_t x_stride_nhead = _info.x_stride_nhead;
auto data_type = _info.data_type;
auto pos_type = _info.pos_type;
auto seq_len = _info.seqlen;
auto nhead = _info.nhead;
auto dhead = _info.dhead;
auto y_stride_seqlen = _info.y_stride_seqlen;
auto y_stride_nhead = _info.y_stride_nhead;
auto x_stride_seqlen = _info.x_stride_seqlen;
auto x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream);
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../../../include/infinicore.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;
......@@ -17,17 +12,23 @@ public:
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh);
__aicore__ inline void Process(int32_t seq_len);
__aicore__ inline void init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
__aicore__ inline void process(size_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void CopyIn(int32_t i);
__aicore__ inline void Compute(int32_t i);
__aicore__ inline void CopyOut(int32_t i);
__aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(size_t i);
private:
TPipe pipe;
......@@ -47,31 +48,37 @@ private:
GlobalTensor<T> sinGm;
GlobalTensor<T> cosGm;
uint32_t _block_idx;
uint32_t _tile_len;
uint32_t _copy_len;
uint32_t _half_copy_len;
size_t _block_idx;
size_t _tile_len;
size_t _copy_len;
size_t _half_copy_len;
// stridey[st_ynt_, st_ynh_, 1]
int32_t st_ynt_;
int32_t st_ynh_;
// stridex[st_xnt_, st_xnh_, 1]
int32_t st_xnt_;
int32_t st_xnh_;
// stridey[_st_ynt, _st_ynh, 1]
ptrdiff_t _st_ynt;
ptrdiff_t _st_ynh;
// stridex[_st_xnt, _st_xnh, 1]
ptrdiff_t _st_xnt;
ptrdiff_t _st_xnh;
};
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t dh,
int32_t st_ynt, int32_t st_ynh,
int32_t st_xnt, int32_t st_xnh) {
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh) {
this->_tile_len = dh;
this->st_ynt_ = st_ynt;
this->st_ynh_ = st_ynh;
this->st_xnt_ = st_xnt;
this->st_xnh_ = st_xnh;
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_half_copy_len = (_tile_len / 2 * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len / 2 : (_tile_len / 2 * sizeof(T) + (BYTE_ALIGN - _tile_len / 2 * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
this->_st_ynt = st_ynt;
this->_st_ynh = st_ynh;
this->_st_xnt = st_xnt;
this->_st_xnh = st_xnh;
_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_half_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_block_idx = GetBlockIdx();
......@@ -96,12 +103,12 @@ __aicore__ inline void RoPEKernel<T, U>::Init(GM_ADDR y, GM_ADDR x, GM_ADDR pos,
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) {
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
LocalTensor<T> inputUb = inQue.AllocTensor<T>();
LocalTensor<T> sinUb = sinQue.AllocTensor<T>();
LocalTensor<T> cosUb = cosQue.AllocTensor<T>();
// Get idx of current tile in total input
auto idx = i * st_xnt_ + _block_idx * st_xnh_;
auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy tile current tile into UB
DataCopy(inputUb, xGm[idx], _copy_len);
// Copy sin cos tile
......@@ -115,7 +122,7 @@ __aicore__ inline void RoPEKernel<T, U>::CopyIn(int32_t i) {
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) {
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> inputUb = inQue.DeQue<T>();
LocalTensor<T> sinUb = sinQue.DeQue<T>();
LocalTensor<T> cosUb = cosQue.DeQue<T>();
......@@ -167,118 +174,130 @@ __aicore__ inline void RoPEKernel<T, U>::Compute(int32_t i) {
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::CopyOut(int32_t i) {
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> outputUb = outQue.DeQue<T>();
auto idy = i * st_ynt_ + _block_idx * st_ynh_;
auto idy = i * _st_ynt + _block_idx * _st_ynh;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm[idy], outputUb, params);
outQue.FreeTensor(outputUb);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::Process(int32_t nt) {
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
for (int32_t i = 0; i < nt; ++i) {
CopyIn(i);
Compute(i);
CopyOut(i);
for (size_t i = 0; i < seq_len; ++i) {
copyIn(i);
compute(i);
copyOut(i);
}
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
case 3: { \
case INFINI_DTYPE_I8: { \
RoPEKernel<TYPE, int8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 4: { \
case INFINI_DTYPE_I16: { \
RoPEKernel<TYPE, int16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 5: { \
case INFINI_DTYPE_I32: { \
RoPEKernel<TYPE, int32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 6: { \
case INFINI_DTYPE_I64: { \
RoPEKernel<TYPE, int64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 7: { \
case INFINI_DTYPE_U8: { \
RoPEKernel<TYPE, uint8_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 8: { \
case INFINI_DTYPE_U16: { \
RoPEKernel<TYPE, uint16_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 9: { \
case INFINI_DTYPE_U32: { \
RoPEKernel<TYPE, uint32_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
case 10: { \
case INFINI_DTYPE_U64: { \
RoPEKernel<TYPE, uint64_t> op; \
op.Init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.Process(seq_len); \
op.init(y, x, pos, sin, cos, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead); \
op.process(seq_len); \
break; \
} \
}
__global__ __aicore__ void rope_f16_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t pos_type){
ROPE_KERNEL(half, pos_type)
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
GM_ADDR x, \
GM_ADDR pos, \
GM_ADDR sin, \
GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
}
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
__global__ __aicore__ void rope_f32_kernel(GM_ADDR y, GM_ADDR x, GM_ADDR pos, GM_ADDR sin, GM_ADDR cos,
int32_t seq_len, int32_t dhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t pos_type) {
ROPE_KERNEL(float, pos_type)
}
#undef DEFINE_ROPE_KERNEL
extern "C" infiniStatus_t rope_kernel_launch(void *y,
void *x,
void *pos,
void *sin,
void *cos,
int32_t seq_len,
int32_t nhead,
int32_t dhead,
int32_t data_type,
int32_t pos_type,
int32_t y_stride_seqlen,
int32_t y_stride_nhead,
int32_t x_stride_seqlen,
int32_t x_stride_nhead,
void *stream) {
switch (data_type) {
case 12: // float16
rope_f16_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break;
case 13: // float32
rope_f32_kernel<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, seq_len, dhead, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, pos_type);
break;
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t dtype,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream) {
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
case DTYPE_ENUM: \
KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
seq_len, \
dhead, \
y_stride_seqlen, \
y_stride_nhead, \
x_stride_seqlen, \
x_stride_nhead, \
pos_type); \
return INFINI_STATUS_SUCCESS;
switch (dtype) {
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
......@@ -9,7 +9,7 @@
#include "cuda/rope_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_aclnn.h"
#include "ascend/rope_ascend.h"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor(
......
......@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
return INFINI_STATUS_SUCCESS;
}
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
......
......@@ -69,5 +69,11 @@ public:
void *stream) const;
};
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);
} // namespace op::swiglu::ascend
#endif // __ACLNN_SWIGLU_H__
......@@ -6,15 +6,20 @@ template <typename T>
class SwigluKernel {
public:
__aicore__ inline SwigluKernel() {}
__aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
__aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
size_t batch_, size_t seq, size_t hd,
ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b);
__aicore__ inline void process();
private:
__aicore__ inline void copyIn(int64_t i);
__aicore__ inline void compute(int64_t i);
__aicore__ inline void copyOut(int64_t i);
__aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(size_t i);
private:
GlobalTensor<T> _c_gm, _a_gm, _b_gm;
......@@ -23,16 +28,21 @@ private:
TPipe _pipe;
float _beta_value = 1.0f;
int64_t _block_idx, _tile_len, _copy_len,
size_t _block_idx, _tile_len, _copy_len,
_batch, _seq_len, _hidden_size,
_stride_seq_a, _stride_seq_b, _stride_seq_c;
int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
};
template <typename T>
__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
size_t batch_, size_t seq, size_t hd,
ptrdiff_t stride_batch_c,
ptrdiff_t stride_batch_a,
ptrdiff_t stride_batch_b,
ptrdiff_t stride_seq_c,
ptrdiff_t stride_seq_a,
ptrdiff_t stride_seq_b) {
// Init Shape & StrideVariables
_batch = batch_;
_seq_len = seq;
......@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
_block_idx = GetBlockIdx();
_tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
_copy_len = alignTileLen<T>(_tile_len, BYTE_ALIGN);
// Set global tensor
_a_gm.SetGlobalBuffer((__gm__ T *)a);
......@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
}
template <typename T>
__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
__aicore__ inline void SwigluKernel<T>::copyIn(size_t i) {
// Alloc tensor from queue memory
LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
......@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
ptrdiff_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
ptrdiff_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
// Copy process_th tile from global tensor to local tensor
DataCopy(aLocal, _a_gm[idxa], _copy_len);
DataCopy(bLocal, _b_gm[idxb], _copy_len);
......@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
}
template <typename T>
__aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
__aicore__ inline void SwigluKernel<T>::compute(size_t i) {
// Deque input tensors from VECIN queue
LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
......@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
}
template <typename T>
__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
__aicore__ inline void SwigluKernel<T>::copyOut(size_t i) {
// Deque output tensor from VECOUT queue
LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
ptrdiff_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
// Copy progress_th tile from local tensor to global tensor
if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
......@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
template <typename T>
__aicore__ inline void SwigluKernel<T>::process() {
for (int64_t i = 0; i < _batch * _seq_len; ++i) {
for (size_t i = 0; i < _batch * _seq_len; ++i) {
copyIn(i);
compute(i);
copyOut(i);
}
}
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
int64_t batch, int64_t seq, int64_t hd, \
int64_t stride_batch_c, \
int64_t stride_batch_a, \
int64_t stride_batch_b, \
int64_t stride_seq_c, \
int64_t stride_seq_a, \
int64_t stride_seq_b) { \
SwigluKernel<TYPE> op; \
op.init(c, a, b, \
batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
op.process(); \
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
size_t batch, size_t seq, size_t hd, \
ptrdiff_t stride_batch_c, \
ptrdiff_t stride_batch_a, \
ptrdiff_t stride_batch_b, \
ptrdiff_t stride_seq_c, \
ptrdiff_t stride_seq_a, \
ptrdiff_t stride_seq_b) { \
SwigluKernel<TYPE> op; \
op.init(c, a, b, \
batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
op.process(); \
}
DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
......@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
case DTYPE_ENUM: \
KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \
c, a, b, \
static_cast<int64_t>(batch), \
static_cast<int64_t>(seq), \
static_cast<int64_t>(hd), \
batch, \
seq, \
hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
return INFINI_STATUS_SUCCESS;
......
......@@ -189,6 +189,9 @@ def test(
)
lib_rope()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment