Unverified Commit beaf1e8c authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #59 from PanZezhong1725/issue/7

issue/7: rmsnorm算子-昇腾
parents 65df17f7 527086e9
......@@ -8,6 +8,10 @@ std::vector<int64_t> inferStorageShape(std::vector<int64_t> shape, std::vector<i
return storageShape;
}
// Total number of elements described by this tensor: the product of all
// dimension extents in `shape` (1 for an empty/scalar shape).
size_t aclnnTensorDescriptor::numel() const {
    size_t count = 1;
    for (const auto dim : shape) {
        count *= static_cast<size_t>(dim);
    }
    return count;
}
aclnnTensorDescriptor::aclnnTensorDescriptor(infiniopTensorDescriptor_t desc, void *data) {
this->ndim = desc->ndim();
this->shape = std::vector<int64_t>(ndim);
......
......@@ -34,10 +34,10 @@ struct aclnnTensorDescriptor {
int64_t storageNdim = 1;
aclTensor *tensor;
// aclnnGemmGetWorkspaceSize only support 2D matrix multiply, so we need to convert 3D tensor to 2D tensor
aclnnTensorDescriptor(aclDataType dtype, const std::vector<int64_t> &shape, const std::vector<int64_t> &strides, void *data = nullptr);
aclnnTensorDescriptor(infiniopTensorDescriptor_t y_desc, void *data = nullptr);
~aclnnTensorDescriptor();
size_t numel() const;
std::string toString();
};
......
#include "rms_norm_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_rms_norm.h>
#include <memory>
namespace op::rms_norm::ascend {
// Backend-private state for the Ascend RMSNorm descriptor. Instances are
// aggregate-initialized in Descriptor::create, so member order matters.
struct Descriptor::Opaque {
    // aclnn executor built once in create() and replayed per row in
    // calculate(). NOTE(review): presumably `mutable` so it can be re-bound
    // from a const context — confirm against Descriptor's method qualifiers.
    mutable aclOpExecutor *executor;
    // Owning pointers to the tensor descriptors bound to the executor's
    // argument slots; freed in the destructor below.
    aclnnTensorDescriptor_t y;    // per-row output slice
    aclnnTensorDescriptor_t x;    // per-row input slice
    aclnnTensorDescriptor_t w;    // weight (gamma) tensor
    aclnnTensorDescriptor_t rstd; // mandatory rstd output (aclnn forbids nullptr)
    // Bytes required by aclnnRmsNorm itself — excludes the extra rstd buffer
    // that create() appends when reporting the total workspace size.
    size_t workspaceSize;
    ~Opaque() {
        delete y;
        delete x;
        delete w;
        delete rstd;
        aclDestroyAclOpExecutor(executor);
    }
};
// Releases backend state: Opaque's destructor frees the tensor descriptors
// and destroys the cached aclnn executor.
Descriptor::~Descriptor() {
    delete _opaque;
}
/// Build an Ascend RMSNorm descriptor.
///
/// Validates the tensor layout via createRMSNormInfo, describes x/y as 1-D
/// row slices (calculate() applies aclnnRmsNorm one row at a time), queries
/// the aclnn workspace size, and caches a repeatable executor in Opaque.
///
/// @param handle   Ascend device handle (reinterpreted from the generic one).
/// @param desc_ptr Receives the newly allocated descriptor on success.
/// @param epsilon  Numerical-stability term forwarded to aclnnRmsNorm.
/// @return INFINI_STATUS_SUCCESS, or the failure status propagated by
///         CHECK_STATUS / CHECK_ACL.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t w_desc,
    float epsilon) {
    RMSNormInfo info;
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
    CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));

    size_t workspace_size = 0;
    aclOpExecutor *executor = nullptr;

    // x and y are described as contiguous 1-D slices of length shape[1]:
    // calculate() loops over rows and rebinds the data address per row.
    std::vector<int64_t> slice_shape = {static_cast<int64_t>((info.shape)[1])};
    auto slice_stride = std::vector<int64_t>(1, 1);

    // Hold the descriptors in unique_ptr until ownership is transferred to
    // Opaque, so an early return from CHECK_ACL below cannot leak them
    // (the original raw `new` pointers leaked on that path).
    auto y = std::make_unique<aclnnTensorDescriptor>(toAclDataType(info.atype), slice_shape, slice_stride);
    auto x = std::make_unique<aclnnTensorDescriptor>(toAclDataType(info.atype), slice_shape, slice_stride);
    auto w = std::make_unique<aclnnTensorDescriptor>(w_desc);

    // Set rstdDesc
    // See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md
    // rstdTensor cannot set nullptr in aclnn
    auto rstd_shape = std::vector<int64_t>(1, 1);
    auto rstd_strides = std::vector<int64_t>(1, 1);
    auto rstd = std::make_unique<aclnnTensorDescriptor>(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides);

    // Query the workspace size and obtain the executor, then mark it
    // repeatable so calculate() can replay it with new tensor addresses.
    // NOTE(review): aclSetAclOpExecutorRepeatable's return value is ignored
    // here as in the original — consider checking it.
    CHECK_ACL(aclnnRmsNormGetWorkspaceSize(x->tensor, w->tensor, static_cast<double>(epsilon), y->tensor, rstd->tensor, &workspace_size, &executor));
    aclSetAclOpExecutorRepeatable(executor);

    // Reserve extra space after the aclnn workspace for the rstd output buffer.
    size_t allWorkspaceSize = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);

    *desc_ptr = new Descriptor(
        new Opaque{executor, y.release(), x.release(), w.release(), rstd.release(), workspace_size},
        info, allWorkspaceSize, handle_ascend->device, handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Apply RMSNorm row by row: the cached repeatable executor is replayed once
// per row of x/y, with only the tensor data addresses rebound between calls.
// `workspace` must be at least workspaceSize() bytes: the aclnn scratch area
// followed by the rstd output buffer.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *y,
                                     const void *x, const void *w, void *stream) {
    if (workspace_size < workspaceSize()) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    auto tw = _opaque->w->tensor;
    auto tx = _opaque->x->tensor;
    auto ty = _opaque->y->tensor;
    auto trstd = _opaque->rstd->tensor;
    // rstd is written into the tail of the workspace, past the aclnn scratch
    // region of _opaque->workspaceSize bytes (see create()).
    void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
    // Element size of the activation dtype, used to turn element strides
    // into byte offsets.
    auto unit = infiniSizeOf(_info.atype);
    // Slot indices 0..3 presumably follow the tensor order passed to
    // aclnnRmsNormGetWorkspaceSize in create(): x=0, gamma=1, y=2, rstd=3
    // — TODO confirm against the CANN aclnn docs.
    // w and rstd addresses are loop-invariant; bind them once.
    AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
    AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
    for (size_t i = 0; i < (_info.shape)[0]; ++i) {
        // Advance by the leading-dimension stride (in elements) per row.
        AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
        AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
        CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::ascend
// Ascend (aclnn) backend header for the RMSNorm operator.
// NOTE(review): identifiers beginning with a double underscore are reserved
// in C++; consider renaming the guard to ACLNN_RMS_NORM_H.
#ifndef __ACLNN_RMS_NORM_H__
#define __ACLNN_RMS_NORM_H__
#include "../rms_norm.h"
// Expands the shared Descriptor boilerplate for the `ascend` backend namespace.
DESCRIPTOR(ascend)
#endif
......@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/rms_norm_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h"
#endif
__C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopHandle_t handle,
......@@ -39,15 +42,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnCreateRMSNormDescriptor((AscendHandle_t)handle,
(RMSNormAclnnDescriptor_t *)desc_ptr,
y_desc,
x_desc,
w_desc,
epsilon);
}
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -85,11 +81,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnGetRMSNormWorkspaceSize((RMSNormAclnnDescriptor_t)desc,
size);
}
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -128,16 +121,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnRMSNorm((RMSNormAclnnDescriptor_t)desc,
workspace,
workspace_size,
y,
x,
w,
stream);
}
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -175,10 +160,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnDestroyRMSNormDescriptor((RMSNormAclnnDescriptor_t)desc);
}
#ifdef ENABLE_ASCEND_API
DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......
......@@ -55,6 +55,10 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
info->shape = std::move(y_desc->shape());
info->y_strides = std::move(y_desc->strides());
info->x_strides = std::move(x_desc->strides());
......
add_defines("ENABLE_ASCEND_API")
local ASCEND_HOME = os.getenv("ASCEND_HOME")
local ASCEND_HOME = os.getenv("ASCEND_HOME") or os.getenv("ASCEND_TOOLKIT_HOME")
local SOC_VERSION = os.getenv("SOC_VERSION")
-- Add include dirs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment