Commit fafb22db authored by zhangyue
Browse files

issue/9: 根据review 修改

parent 120b4348
...@@ -25,7 +25,7 @@ include_directories( ...@@ -25,7 +25,7 @@ include_directories(
ascendc_library(ascend_kernels STATIC ascendc_library(ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_kernel.cpp ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp # ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp # ../../ops/random_sample/ascend/random_sample_kernel.cpp
) )
......
#include "causal_softmax_aclnn.h" #include "causal_softmax_ascend.h"
#include "../../../devices/ascend/common_ascend.h" #include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h> #include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h> #include <aclnnop/aclnn_softmax.h>
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "cuda/causal_softmax_cuda.cuh" #include "cuda/causal_softmax_cuda.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/causal_softmax_aclnn.h" #include "ascend/causal_softmax_ascend.h"
#endif #endif
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
......
#include "swiglu_aclnn.h" #include "swiglu_ascend.h"
#include "../../../devices/ascend/common_ascend.h" #include "../../../devices/ascend/common_ascend.h"
namespace op::swiglu::ascend { namespace op::swiglu::ascend {
......
...@@ -20,23 +20,22 @@ public: ...@@ -20,23 +20,22 @@ public:
std::vector<ptrdiff_t> b_strides; std::vector<ptrdiff_t> b_strides;
static utils::Result<SwigluInfo> create(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { static utils::Result<SwigluInfo> create(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
if (!c_desc || !a_desc || !b_desc) { CHECK_OR_RETURN(c_desc && a_desc && b_desc, INFINI_STATUS_BAD_PARAM);
return INFINI_STATUS_BAD_PARAM; CHECK_OR_RETURN(!c_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES);
} CHECK_OR_RETURN(c_desc->ndim() == a_desc->ndim()
if (c_desc->hasBroadcastDim()) { && c_desc->ndim() == b_desc->ndim()
return INFINI_STATUS_BAD_TENSOR_STRIDES; && (c_desc->ndim() == 2 || c_desc->ndim() == 3),
} INFINI_STATUS_BAD_TENSOR_SHAPE);
if (c_desc->ndim() != a_desc->ndim() || c_desc->ndim() != b_desc->ndim() || (c_desc->ndim() != 2 && c_desc->ndim() != 3)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape()); CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape());
int32_t ndim = c_desc->ndim(); int32_t ndim = c_desc->ndim();
if (c_desc->stride(ndim - 1) != 1 || a_desc->stride(ndim - 1) != 1 || b_desc->stride(ndim - 1) != 1) { CHECK_OR_RETURN(c_desc->stride(ndim - 1) == 1
return INFINI_STATUS_BAD_TENSOR_STRIDES; && a_desc->stride(ndim - 1) == 1
} && b_desc->stride(ndim - 1) == 1,
if (c_desc->dtype() != a_desc->dtype() || c_desc->dtype() != b_desc->dtype()) { INFINI_STATUS_BAD_TENSOR_STRIDES);
return INFINI_STATUS_BAD_TENSOR_DTYPE; CHECK_OR_RETURN(c_desc->dtype() == a_desc->dtype()
} && c_desc->dtype() == b_desc->dtype(),
INFINI_STATUS_BAD_TENSOR_DTYPE);
return utils::Result<SwigluInfo>(SwigluInfo{ return utils::Result<SwigluInfo>(SwigluInfo{
c_desc->dtype(), c_desc->dtype(),
c_desc->shape(), c_desc->shape(),
......
...@@ -6,117 +6,117 @@ template <typename T> ...@@ -6,117 +6,117 @@ template <typename T>
class SwigluKernel { class SwigluKernel {
public: public:
__aicore__ inline SwigluKernel() {} __aicore__ inline SwigluKernel() {}
__aicore__ inline void Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b); int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
__aicore__ inline void Process(); __aicore__ inline void process();
private: private:
__aicore__ inline void CopyIn(int64_t i); __aicore__ inline void copyIn(int64_t i);
__aicore__ inline void Compute(int64_t i); __aicore__ inline void compute(int64_t i);
__aicore__ inline void CopyOut(int64_t i); __aicore__ inline void copyOut(int64_t i);
private: private:
GlobalTensor<T> cGm, aGm, bGm; GlobalTensor<T> _c_gm, _a_gm, _b_gm;
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueA, inQueueB; TQue<QuePosition::VECIN, BUFFER_NUM> _in_queue_a, _in_queue_b;
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueC; TQue<QuePosition::VECOUT, BUFFER_NUM> _out_queue_c;
TPipe pipe; TPipe _pipe;
float _beta_value = 1.0f; float _beta_value = 1.0f;
int64_t _block_idx, _tile_len, _copy_len, int64_t _block_idx, _tile_len, _copy_len,
batch, seq_len, hidden_size, _batch, _seq_len, _hidden_size,
strideSeqA, strideSeqB, strideSeqC; _stride_seq_a, _stride_seq_b, _stride_seq_c;
int64_t strideBatchA = 1, strideBatchB = 1, strideBatchC = 1; int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
}; };
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) { int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
// Init Shape & StrideVariables // Init Shape & StrideVariables
batch = batch_; _batch = batch_;
seq_len = seq; _seq_len = seq;
hidden_size = hd; _hidden_size = hd;
strideBatchA = stride_batch_a; _stride_batch_a = stride_batch_a;
strideBatchB = stride_batch_b; _stride_batch_b = stride_batch_b;
strideBatchC = stride_batch_c; _stride_batch_c = stride_batch_c;
strideSeqA = stride_seq_a; _stride_seq_a = stride_seq_a;
strideSeqB = stride_seq_b; _stride_seq_b = stride_seq_b;
strideSeqC = stride_seq_c; _stride_seq_c = stride_seq_c;
_block_idx = GetBlockIdx(); _block_idx = GetBlockIdx();
_tile_len = _block_idx < (hidden_size % BLOCK_NUM) ? (hidden_size / BLOCK_NUM) + 1 : (hidden_size / BLOCK_NUM); _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
_copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T); _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
// Set global tensor // Set global tensor
aGm.SetGlobalBuffer((__gm__ T *)a); _a_gm.SetGlobalBuffer((__gm__ T *)a);
bGm.SetGlobalBuffer((__gm__ T *)b); _b_gm.SetGlobalBuffer((__gm__ T *)b);
cGm.SetGlobalBuffer((__gm__ T *)c); _c_gm.SetGlobalBuffer((__gm__ T *)c);
// Pipe alloc memory to queue, the unit is bytes // _pipe alloc memory to queue, the unit is bytes
pipe.InitBuffer(inQueueA, BUFFER_NUM, _copy_len * sizeof(T)); _pipe.InitBuffer(_in_queue_a, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(inQueueB, BUFFER_NUM, _copy_len * sizeof(T)); _pipe.InitBuffer(_in_queue_b, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(outQueueC, BUFFER_NUM, _copy_len * sizeof(T)); _pipe.InitBuffer(_out_queue_c, BUFFER_NUM, _copy_len * sizeof(T));
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyIn(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
// Alloc tensor from queue memory // Alloc tensor from queue memory
LocalTensor<T> aLocal = inQueueA.AllocTensor<T>(); LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
LocalTensor<T> bLocal = inQueueB.AllocTensor<T>(); LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
// Get idx of current tile // Get idx of current tile
auto batchIdx = batch == 1 ? 0 : i / seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seqIdx = batch == 1 ? i : i % seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxa = batchIdx * strideBatchA + seqIdx * strideSeqA + _block_idx * _tile_len; int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
int64_t idxb = batchIdx * strideBatchB + seqIdx * strideSeqB + _block_idx * _tile_len; int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
// Copy process_th tile from global tensor to local tensor // Copy process_th tile from global tensor to local tensor
DataCopy(aLocal, aGm[idxa], _copy_len); DataCopy(aLocal, _a_gm[idxa], _copy_len);
DataCopy(bLocal, bGm[idxb], _copy_len); DataCopy(bLocal, _b_gm[idxb], _copy_len);
// Enque input tensor to VECIN queue // Enque input tensor to VECIN queue
inQueueA.EnQue(aLocal); _in_queue_a.EnQue(aLocal);
inQueueB.EnQue(bLocal); _in_queue_b.EnQue(bLocal);
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::Compute(int64_t i) { __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
// Deque input tensors from VECIN queue // Deque input tensors from VECIN queue
LocalTensor<T> aLocal = inQueueA.DeQue<T>(); LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
LocalTensor<T> bLocal = inQueueB.DeQue<T>(); LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
LocalTensor<T> cLocal = outQueueC.AllocTensor<T>(); LocalTensor<T> cLocal = _out_queue_c.AllocTensor<T>();
// Call SwiGLU ascend api // Call SwiGLU ascend api
SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len); SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len);
// Enque result and free input // Enque result and free input
outQueueC.EnQue<T>(cLocal); _out_queue_c.EnQue<T>(cLocal);
inQueueA.FreeTensor(aLocal); _in_queue_a.FreeTensor(aLocal);
inQueueB.FreeTensor(bLocal); _in_queue_b.FreeTensor(bLocal);
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyOut(int64_t i) { __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
// Deque output tensor from VECOUT queue // Deque output tensor from VECOUT queue
LocalTensor<T> cLocal = outQueueC.DeQue<T>(); LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
auto batchIdx = batch == 1 ? 0 : i / seq_len; auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
auto seqIdx = batch == 1 ? i : i % seq_len; auto seq_idx = _batch == 1 ? i : i % _seq_len;
int64_t idxc = batchIdx * strideBatchC + seqIdx * strideSeqC + _block_idx * _tile_len; int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
// Copy progress_th tile from local tensor to global tensor // Copy progress_th tile from local tensor to global tensor
if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) { if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(cGm[idxc], cLocal, dcep); DataCopyPad(_c_gm[idxc], cLocal, dcep);
} else { } else {
DataCopy(cGm[idxc], cLocal, _tile_len); DataCopy(_c_gm[idxc], cLocal, _tile_len);
} }
// Free output Local tensor // Free output Local tensor
outQueueC.FreeTensor(cLocal); _out_queue_c.FreeTensor(cLocal);
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::Process() { __aicore__ inline void SwigluKernel<T>::process() {
for (int64_t i = 0; i < batch * seq_len; ++i) { for (int64_t i = 0; i < _batch * _seq_len; ++i) {
CopyIn(i); copyIn(i);
Compute(i); compute(i);
CopyOut(i); copyOut(i);
} }
} }
...@@ -130,11 +130,11 @@ __aicore__ inline void SwigluKernel<T>::Process() { ...@@ -130,11 +130,11 @@ __aicore__ inline void SwigluKernel<T>::Process() {
int64_t stride_seq_a, \ int64_t stride_seq_a, \
int64_t stride_seq_b) { \ int64_t stride_seq_b) { \
SwigluKernel<TYPE> op; \ SwigluKernel<TYPE> op; \
op.Init(c, a, b, \ op.init(c, a, b, \
batch, seq, hd, \ batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \ stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \ stride_seq_c, stride_seq_a, stride_seq_b); \
op.Process(); \ op.process(); \
} }
DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half) DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "kunlun/swiglu_kunlun.h" #include "kunlun/swiglu_kunlun.h"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_aclnn.h" #include "ascend/swiglu_ascend.h"
#endif #endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor( __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment