Commit fafb22db authored by zhangyue

issue/9: revisions based on review feedback

parent 120b4348
@@ -25,7 +25,7 @@ include_directories(
 ascendc_library(ascend_kernels STATIC
-    ../../ops/swiglu/ascend/swiglu_kernel.cpp
+    ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
     # ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
     # ../../ops/random_sample/ascend/random_sample_kernel.cpp
 )
......
#include "causal_softmax_aclnn.h"
#include "causal_softmax_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h>
......
@@ -9,7 +9,7 @@
 #include "cuda/causal_softmax_cuda.cuh"
 #endif
 #ifdef ENABLE_ASCEND_API
-#include "ascend/causal_softmax_aclnn.h"
+#include "ascend/causal_softmax_ascend.h"
 #endif
 __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
......
#include "swiglu_aclnn.h"
#include "swiglu_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::swiglu::ascend {
......
@@ -20,23 +20,22 @@ public:
     std::vector<ptrdiff_t> b_strides;
     static utils::Result<SwigluInfo> create(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
-        if (!c_desc || !a_desc || !b_desc) {
-            return INFINI_STATUS_BAD_PARAM;
-        }
-        if (c_desc->hasBroadcastDim()) {
-            return INFINI_STATUS_BAD_TENSOR_STRIDES;
-        }
-        if (c_desc->ndim() != a_desc->ndim() || c_desc->ndim() != b_desc->ndim() || (c_desc->ndim() != 2 && c_desc->ndim() != 3)) {
-            return INFINI_STATUS_BAD_TENSOR_SHAPE;
-        }
+        CHECK_OR_RETURN(c_desc && a_desc && b_desc, INFINI_STATUS_BAD_PARAM);
+        CHECK_OR_RETURN(!c_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES);
+        CHECK_OR_RETURN(c_desc->ndim() == a_desc->ndim()
+                            && c_desc->ndim() == b_desc->ndim()
+                            && (c_desc->ndim() == 2 || c_desc->ndim() == 3),
+                        INFINI_STATUS_BAD_TENSOR_SHAPE);
+        CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape());
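+        // (CHECK_OR_RETURN(cond, status) and CHECK_SAME_SHAPE are assumed to
+        // expand to an early error return when a check fails.)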
         int32_t ndim = c_desc->ndim();
-        if (c_desc->stride(ndim - 1) != 1 || a_desc->stride(ndim - 1) != 1 || b_desc->stride(ndim - 1) != 1) {
-            return INFINI_STATUS_BAD_TENSOR_STRIDES;
-        }
-        if (c_desc->dtype() != a_desc->dtype() || c_desc->dtype() != b_desc->dtype()) {
-            return INFINI_STATUS_BAD_TENSOR_DTYPE;
-        }
+        CHECK_OR_RETURN(c_desc->stride(ndim - 1) == 1
+                            && a_desc->stride(ndim - 1) == 1
+                            && b_desc->stride(ndim - 1) == 1,
+                        INFINI_STATUS_BAD_TENSOR_STRIDES);
+        CHECK_OR_RETURN(c_desc->dtype() == a_desc->dtype()
+                            && c_desc->dtype() == b_desc->dtype(),
+                        INFINI_STATUS_BAD_TENSOR_DTYPE);
         return utils::Result<SwigluInfo>(SwigluInfo{
             c_desc->dtype(),
             c_desc->shape(),
......
@@ -6,117 +6,117 @@ template <typename T>
 class SwigluKernel {
 public:
     __aicore__ inline SwigluKernel() {}
-    __aicore__ inline void Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
+    __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
                                 int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                 int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
-    __aicore__ inline void Process();
+    __aicore__ inline void process();
 private:
-    __aicore__ inline void CopyIn(int64_t i);
-    __aicore__ inline void Compute(int64_t i);
-    __aicore__ inline void CopyOut(int64_t i);
+    __aicore__ inline void copyIn(int64_t i);
+    __aicore__ inline void compute(int64_t i);
+    __aicore__ inline void copyOut(int64_t i);
 private:
-    GlobalTensor<T> cGm, aGm, bGm;
-    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueA, inQueueB;
-    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueC;
+    GlobalTensor<T> _c_gm, _a_gm, _b_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> _in_queue_a, _in_queue_b;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> _out_queue_c;
-    TPipe pipe;
+    TPipe _pipe;
     float _beta_value = 1.0f;
     int64_t _block_idx, _tile_len, _copy_len,
-        batch, seq_len, hidden_size,
-        strideSeqA, strideSeqB, strideSeqC;
-    int64_t strideBatchA = 1, strideBatchB = 1, strideBatchC = 1;
+        _batch, _seq_len, _hidden_size,
+        _stride_seq_a, _stride_seq_b, _stride_seq_c;
+    int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
 };
 template <typename T>
-__aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
+__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
                                              int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                              int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
     // Init shape & stride variables
-    batch = batch_;
-    seq_len = seq;
-    hidden_size = hd;
-    strideBatchA = stride_batch_a;
-    strideBatchB = stride_batch_b;
-    strideBatchC = stride_batch_c;
-    strideSeqA = stride_seq_a;
-    strideSeqB = stride_seq_b;
-    strideSeqC = stride_seq_c;
+    _batch = batch_;
+    _seq_len = seq;
+    _hidden_size = hd;
+    _stride_batch_a = stride_batch_a;
+    _stride_batch_b = stride_batch_b;
+    _stride_batch_c = stride_batch_c;
+    _stride_seq_a = stride_seq_a;
+    _stride_seq_b = stride_seq_b;
+    _stride_seq_c = stride_seq_c;
     _block_idx = GetBlockIdx();
-    _tile_len = _block_idx < (hidden_size % BLOCK_NUM) ? (hidden_size / BLOCK_NUM) + 1 : (hidden_size / BLOCK_NUM);
+    _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
     _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
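     // _tile_len is this block's slice of hidden_size (the first
     // hidden_size % BLOCK_NUM blocks take one extra element); _copy_len
     // rounds _tile_len up so each copy spans a whole multiple of BYTE_ALIGN bytes.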
     // Set global tensor
-    aGm.SetGlobalBuffer((__gm__ T *)a);
-    bGm.SetGlobalBuffer((__gm__ T *)b);
-    cGm.SetGlobalBuffer((__gm__ T *)c);
-    // Pipe alloc memory to queue, the unit is bytes
-    pipe.InitBuffer(inQueueA, BUFFER_NUM, _copy_len * sizeof(T));
-    pipe.InitBuffer(inQueueB, BUFFER_NUM, _copy_len * sizeof(T));
-    pipe.InitBuffer(outQueueC, BUFFER_NUM, _copy_len * sizeof(T));
+    _a_gm.SetGlobalBuffer((__gm__ T *)a);
+    _b_gm.SetGlobalBuffer((__gm__ T *)b);
+    _c_gm.SetGlobalBuffer((__gm__ T *)c);
+    // _pipe alloc memory to queue, the unit is bytes
+    _pipe.InitBuffer(_in_queue_a, BUFFER_NUM, _copy_len * sizeof(T));
+    _pipe.InitBuffer(_in_queue_b, BUFFER_NUM, _copy_len * sizeof(T));
+    _pipe.InitBuffer(_out_queue_c, BUFFER_NUM, _copy_len * sizeof(T));
 }
 template <typename T>
-__aicore__ inline void SwigluKernel<T>::CopyIn(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
     // Alloc tensor from queue memory
-    LocalTensor<T> aLocal = inQueueA.AllocTensor<T>();
-    LocalTensor<T> bLocal = inQueueB.AllocTensor<T>();
+    LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
+    LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
     // Get idx of current tile
-    auto batchIdx = batch == 1 ? 0 : i / seq_len;
-    auto seqIdx = batch == 1 ? i : i % seq_len;
+    auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
+    auto seq_idx = _batch == 1 ? i : i % _seq_len;
-    int64_t idxa = batchIdx * strideBatchA + seqIdx * strideSeqA + _block_idx * _tile_len;
-    int64_t idxb = batchIdx * strideBatchB + seqIdx * strideSeqB + _block_idx * _tile_len;
+    int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
+    int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
     // Copy the i-th tile from global tensor to local tensor
-    DataCopy(aLocal, aGm[idxa], _copy_len);
-    DataCopy(bLocal, bGm[idxb], _copy_len);
+    DataCopy(aLocal, _a_gm[idxa], _copy_len);
+    DataCopy(bLocal, _b_gm[idxb], _copy_len);
     // Enqueue input tensors to VECIN queue
-    inQueueA.EnQue(aLocal);
-    inQueueB.EnQue(bLocal);
+    _in_queue_a.EnQue(aLocal);
+    _in_queue_b.EnQue(bLocal);
 }
 template <typename T>
-__aicore__ inline void SwigluKernel<T>::Compute(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
     // Deque input tensors from VECIN queue
-    LocalTensor<T> aLocal = inQueueA.DeQue<T>();
-    LocalTensor<T> bLocal = inQueueB.DeQue<T>();
-    LocalTensor<T> cLocal = outQueueC.AllocTensor<T>();
+    LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
+    LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
+    LocalTensor<T> cLocal = _out_queue_c.AllocTensor<T>();
     // Call the SwiGLU Ascend API
     SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len);
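     // (Assumed AscendC SwiGLU semantics: cLocal = Swish(aLocal) * bLocal,
     // with Swish(x) = x * sigmoid(_beta_value * x), over _copy_len elements.)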
     // Enqueue result and free inputs
-    outQueueC.EnQue<T>(cLocal);
-    inQueueA.FreeTensor(aLocal);
-    inQueueB.FreeTensor(bLocal);
+    _out_queue_c.EnQue<T>(cLocal);
+    _in_queue_a.FreeTensor(aLocal);
+    _in_queue_b.FreeTensor(bLocal);
 }
 template <typename T>
-__aicore__ inline void SwigluKernel<T>::CopyOut(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
     // Deque output tensor from VECOUT queue
-    LocalTensor<T> cLocal = outQueueC.DeQue<T>();
-    auto batchIdx = batch == 1 ? 0 : i / seq_len;
-    auto seqIdx = batch == 1 ? i : i % seq_len;
-    int64_t idxc = batchIdx * strideBatchC + seqIdx * strideSeqC + _block_idx * _tile_len;
+    LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
+    auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
+    auto seq_idx = _batch == 1 ? i : i % _seq_len;
+    int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
     // Copy the i-th tile from local tensor to global tensor
     if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
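         // Unaligned tail: DataCopyPad writes exactly _tile_len elements
         // (blockLen is given in bytes), avoiding an over-aligned DataCopy.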
         DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
-        DataCopyPad(cGm[idxc], cLocal, dcep);
+        DataCopyPad(_c_gm[idxc], cLocal, dcep);
     } else {
-        DataCopy(cGm[idxc], cLocal, _tile_len);
+        DataCopy(_c_gm[idxc], cLocal, _tile_len);
     }
     // Free output local tensor
-    outQueueC.FreeTensor(cLocal);
+    _out_queue_c.FreeTensor(cLocal);
 }
 template <typename T>
-__aicore__ inline void SwigluKernel<T>::Process() {
-    for (int64_t i = 0; i < batch * seq_len; ++i) {
-        CopyIn(i);
-        Compute(i);
-        CopyOut(i);
+__aicore__ inline void SwigluKernel<T>::process() {
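+    // One iteration per (batch, seq) row; with BUFFER_NUM > 1 the queues let
+    // copy-in, compute, and copy-out of successive rows overlap.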
+    for (int64_t i = 0; i < _batch * _seq_len; ++i) {
+        copyIn(i);
+        compute(i);
+        copyOut(i);
     }
 }
@@ -130,11 +130,11 @@ __aicore__ inline void SwigluKernel<T>::Process() {
         int64_t stride_seq_a, \
         int64_t stride_seq_b) { \
         SwigluKernel<TYPE> op; \
-        op.Init(c, a, b, \
+        op.init(c, a, b, \
                 batch, seq, hd, \
                 stride_batch_c, stride_batch_a, stride_batch_b, \
                 stride_seq_c, stride_seq_a, stride_seq_b); \
-        op.Process(); \
+        op.process(); \
     }
 DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
......
@@ -12,7 +12,7 @@
 #include "kunlun/swiglu_kunlun.h"
 #endif
 #ifdef ENABLE_ASCEND_API
-#include "ascend/swiglu_aclnn.h"
+#include "ascend/swiglu_ascend.h"
 #endif
 __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
......