"googlemock/vscode:/vscode.git/clone" did not exist on "717ce7feb87278b81dad756d973693024621f8e8"
Commit d7385575 authored by zhangyue's avatar zhangyue
Browse files

issue/9: fix review comment

parent 6e1491bd
// Common definitions shared by the Ascend device kernels.
//
// NOTE: the previous guard name (__INFINIOP_ASCEND_KERNEL_COMMON_H__) used a
// leading double underscore, which is reserved for the implementation in C++;
// renamed to a conforming identifier.
#ifndef INFINIOP_ASCEND_KERNEL_COMMON_H
#define INFINIOP_ASCEND_KERNEL_COMMON_H

#include "../../../../include/infinicore.h"
#include "kernel_operator.h"

// Number of AI cores a kernel is launched on (the block dimension of <<<...>>>).
constexpr int32_t BLOCK_NUM = 8;
// Queue depth used for the VECIN/VECOUT TQue instances (2 => double buffering).
constexpr int32_t BUFFER_NUM = 2;
// Alignment (in bytes) required for plain DataCopy transfer lengths; kernels
// fall back to DataCopyPad-style paths when a tile is not 32-byte aligned.
constexpr int32_t BYTE_ALIGN = 32;

#endif // INFINIOP_ASCEND_KERNEL_COMMON_H
...@@ -26,25 +26,26 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr ...@@ -26,25 +26,26 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Device-side launcher, implemented in the Ascend kernel translation unit.
// Uses size_t for extents and ptrdiff_t for strides so negative strides and
// large tensors are representable (dtype is the typed enum, not a raw int).
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);

// Runs the SwiGLU kernel on the two input tensors, writing into c.
//
// Shapes: a 2-D tensor is treated as a single batch of (seq, hidden); a 3-D
// tensor is (batch, seq, hidden). Strides are taken from _info accordingly.
// workspace/workspace_size are part of the common operator interface and are
// not used by this implementation.
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *c,
                                     std::vector<const void *> inputs,
                                     void *stream) const {
    // For ndim == 2 there is no batch axis: batch collapses to 1 and the
    // batch strides are never multiplied by a nonzero index (placeholder 1).
    auto batch = _info.ndim == 2 ? 1 : _info.shape[0];
    auto seq_len = _info.ndim == 2 ? _info.shape[0] : _info.shape[1];
    auto hidden_size = _info.shape[_info.ndim - 1];
    auto stride_batch_c = _info.ndim == 2 ? 1 : _info.c_strides[0];
    auto stride_batch_a = _info.ndim == 2 ? 1 : _info.a_strides[0];
    auto stride_batch_b = _info.ndim == 2 ? 1 : _info.b_strides[0];
    auto stride_seq_c = _info.ndim == 2 ? _info.c_strides[0] : _info.c_strides[1];
    auto stride_seq_a = _info.ndim == 2 ? _info.a_strides[0] : _info.a_strides[1];
    auto stride_seq_b = _info.ndim == 2 ? _info.b_strides[0] : _info.b_strides[1];
    return swiglu_kernel_launch(c, (void *)inputs[0], (void *)inputs[1], _info.dtype,
                                batch, seq_len, hidden_size,
                                stride_batch_c, stride_batch_a, stride_batch_b,
                                stride_seq_c, stride_seq_a, stride_seq_b, stream);
}
......
...@@ -39,11 +39,11 @@ public: ...@@ -39,11 +39,11 @@ public:
} }
return utils::Result<SwigluInfo>(SwigluInfo{ return utils::Result<SwigluInfo>(SwigluInfo{
c_desc->dtype(), c_desc->dtype(),
std::move(c_desc->shape()), c_desc->shape(),
ndim, ndim,
std::move(c_desc->strides()), c_desc->strides(),
std::move(a_desc->strides()), a_desc->strides(),
std::move(b_desc->strides()), b_desc->strides(),
}); });
} }
}; };
......
#include "../../../../../include/infinicore.h" #include "../../../devices/ascend/ascend_kernel_common.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
// ubsize = 196KB
using namespace AscendC; using namespace AscendC;
...@@ -12,15 +6,15 @@ template <typename T> ...@@ -12,15 +6,15 @@ template <typename T>
class SwigluKernel { class SwigluKernel {
public: public:
__aicore__ inline SwigluKernel() {} __aicore__ inline SwigluKernel() {}
__aicore__ inline void Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int batch_, int seq, int hd, __aicore__ inline void Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int stride_batch_c, int stride_batch_a, int stride_batch_b, int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int stride_seq_c, int stride_seq_a, int stride_seq_b); int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
__aicore__ inline void Process(); __aicore__ inline void Process();
private: private:
__aicore__ inline void CopyIn(int32_t i); __aicore__ inline void CopyIn(int64_t i);
__aicore__ inline void Compute(int32_t i); __aicore__ inline void Compute(int64_t i);
__aicore__ inline void CopyOut(int32_t i); __aicore__ inline void CopyOut(int64_t i);
private: private:
GlobalTensor<T> cGm, aGm, bGm; GlobalTensor<T> cGm, aGm, bGm;
...@@ -28,18 +22,17 @@ private: ...@@ -28,18 +22,17 @@ private:
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueC; TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueC;
TPipe pipe; TPipe pipe;
uint32_t _data_size = 0;
float _beta_value = 1.0f; float _beta_value = 1.0f;
uint32_t _block_idx, _tile_len, _copy_len; int64_t _block_idx, _tile_len, _copy_len,
uint32_t batch, seq_len, hidden_size; batch, seq_len, hidden_size,
int32_t strideBatchA = 1, strideBatchB = 1, strideBatchC = 1; strideSeqA, strideSeqB, strideSeqC;
int32_t strideSeqA, strideSeqB, strideSeqC; int64_t strideBatchA = 1, strideBatchB = 1, strideBatchC = 1;
}; };
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int batch_, int seq, int hd, __aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
int stride_batch_c, int stride_batch_a, int stride_batch_b, int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int stride_seq_c, int stride_seq_a, int stride_seq_b) { int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
// Init Shape & StrideVariables // Init Shape & StrideVariables
batch = batch_; batch = batch_;
seq_len = seq; seq_len = seq;
...@@ -67,7 +60,7 @@ __aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in ...@@ -67,7 +60,7 @@ __aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) { __aicore__ inline void SwigluKernel<T>::CopyIn(int64_t i) {
// Alloc tensor from queue memory // Alloc tensor from queue memory
LocalTensor<T> aLocal = inQueueA.AllocTensor<T>(); LocalTensor<T> aLocal = inQueueA.AllocTensor<T>();
LocalTensor<T> bLocal = inQueueB.AllocTensor<T>(); LocalTensor<T> bLocal = inQueueB.AllocTensor<T>();
...@@ -75,8 +68,8 @@ __aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) { ...@@ -75,8 +68,8 @@ __aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) {
auto batchIdx = batch == 1 ? 0 : i / seq_len; auto batchIdx = batch == 1 ? 0 : i / seq_len;
auto seqIdx = batch == 1 ? i : i % seq_len; auto seqIdx = batch == 1 ? i : i % seq_len;
int32_t idxa = batchIdx * strideBatchA + seqIdx * strideSeqA + _block_idx * _tile_len; int64_t idxa = batchIdx * strideBatchA + seqIdx * strideSeqA + _block_idx * _tile_len;
int32_t idxb = batchIdx * strideBatchB + seqIdx * strideSeqB + _block_idx * _tile_len; int64_t idxb = batchIdx * strideBatchB + seqIdx * strideSeqB + _block_idx * _tile_len;
// Copy process_th tile from global tensor to local tensor // Copy process_th tile from global tensor to local tensor
DataCopy(aLocal, aGm[idxa], _copy_len); DataCopy(aLocal, aGm[idxa], _copy_len);
DataCopy(bLocal, bGm[idxb], _copy_len); DataCopy(bLocal, bGm[idxb], _copy_len);
...@@ -87,7 +80,7 @@ __aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) { ...@@ -87,7 +80,7 @@ __aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::Compute(int32_t i) { __aicore__ inline void SwigluKernel<T>::Compute(int64_t i) {
// Deque input tensors from VECIN queue // Deque input tensors from VECIN queue
LocalTensor<T> aLocal = inQueueA.DeQue<T>(); LocalTensor<T> aLocal = inQueueA.DeQue<T>();
LocalTensor<T> bLocal = inQueueB.DeQue<T>(); LocalTensor<T> bLocal = inQueueB.DeQue<T>();
...@@ -101,12 +94,12 @@ __aicore__ inline void SwigluKernel<T>::Compute(int32_t i) { ...@@ -101,12 +94,12 @@ __aicore__ inline void SwigluKernel<T>::Compute(int32_t i) {
} }
template <typename T> template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyOut(int32_t i) { __aicore__ inline void SwigluKernel<T>::CopyOut(int64_t i) {
// Deque output tensor from VECOUT queue // Deque output tensor from VECOUT queue
LocalTensor<T> cLocal = outQueueC.DeQue<T>(); LocalTensor<T> cLocal = outQueueC.DeQue<T>();
auto batchIdx = batch == 1 ? 0 : i / seq_len; auto batchIdx = batch == 1 ? 0 : i / seq_len;
auto seqIdx = batch == 1 ? i : i % seq_len; auto seqIdx = batch == 1 ? i : i % seq_len;
int32_t idxc = batchIdx * strideBatchC + seqIdx * strideSeqC + _block_idx * _tile_len; int64_t idxc = batchIdx * strideBatchC + seqIdx * strideSeqC + _block_idx * _tile_len;
// Copy progress_th tile from local tensor to global tensor // Copy progress_th tile from local tensor to global tensor
if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) { if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
...@@ -120,7 +113,7 @@ __aicore__ inline void SwigluKernel<T>::CopyOut(int32_t i) { ...@@ -120,7 +113,7 @@ __aicore__ inline void SwigluKernel<T>::CopyOut(int32_t i) {
// Drives the pipeline: for every (batch, seq) row, stage the inputs into
// local memory, compute the SwiGLU tile, and write the result back out.
template <typename T>
__aicore__ inline void SwigluKernel<T>::Process() {
    // int64_t loop index: batch * seq_len may exceed 32-bit range.
    const int64_t total_rows = batch * seq_len;
    for (int64_t i = 0; i < total_rows; ++i) {
        CopyIn(i);
        Compute(i);
        CopyOut(i);
    }
}
// Device entry point for F16 (half) tensors: constructs a SwigluKernel<half>
// and runs the full copy-in / compute / copy-out pipeline.
__global__ __aicore__ void swiglu_kernel_half(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                              int64_t batch, int64_t seq, int64_t hd,
                                              int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                              int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
    SwigluKernel<half> op;
    op.Init(c, a, b,
            batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
    op.Process();
}
// Device entry point for F32 (float) tensors: constructs a SwigluKernel<float>
// and runs the full copy-in / compute / copy-out pipeline.
__global__ __aicore__ void swiglu_kernel_float(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                               int64_t batch, int64_t seq, int64_t hd,
                                               int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                               int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
    SwigluKernel<float> op;
    op.Init(c, a, b,
            batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
    op.Process();
}
// Host-side C entry point: dispatches on the tensor dtype to the matching
// typed device kernel and launches it on the given stream.
//
// Returns INFINI_STATUS_SUCCESS on launch, or INFINI_STATUS_BAD_TENSOR_DTYPE
// for dtypes other than F16/F32.
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream) {

// Expands to one switch case: launch KERNEL_NAME on BLOCK_NUM cores with the
// extents widened to the kernel's signed int64_t parameters, then report
// success. Scoped to this function via #undef below.
#define LAUNCH_SWIGLU_KERNEL(DTYPE_ENUM, KERNEL_NAME)      \
    case DTYPE_ENUM:                                       \
        KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>(       \
            c, a, b,                                       \
            static_cast<int64_t>(batch),                   \
            static_cast<int64_t>(seq),                     \
            static_cast<int64_t>(hd),                      \
            stride_batch_c, stride_batch_a, stride_batch_b,\
            stride_seq_c, stride_seq_a, stride_seq_b);     \
        return INFINI_STATUS_SUCCESS;

    switch (dtype) {
        LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F16, swiglu_kernel_half)
        LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F32, swiglu_kernel_float)
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
#undef LAUNCH_SWIGLU_KERNEL
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment