Commit 527086e9 authored by zhangyunze's avatar zhangyunze
Browse files

fix: 依据rmsnorm昇腾算子官方更新,删除cast操作

parent d307db53
...@@ -8,7 +8,7 @@ std::vector<int64_t> inferStorageShape(std::vector<int64_t> shape, std::vector<i ...@@ -8,7 +8,7 @@ std::vector<int64_t> inferStorageShape(std::vector<int64_t> shape, std::vector<i
return storageShape; return storageShape;
} }
size_t aclnnTensorDescriptor::size() const { size_t aclnnTensorDescriptor::numel() const {
return std::accumulate(shape.begin(), shape.end(), (size_t)1, std::multiplies<size_t>()); return std::accumulate(shape.begin(), shape.end(), (size_t)1, std::multiplies<size_t>());
} }
......
...@@ -37,7 +37,7 @@ struct aclnnTensorDescriptor { ...@@ -37,7 +37,7 @@ struct aclnnTensorDescriptor {
aclnnTensorDescriptor(aclDataType dtype, const std::vector<int64_t> &shape, const std::vector<int64_t> &strides, void *data = nullptr); aclnnTensorDescriptor(aclDataType dtype, const std::vector<int64_t> &shape, const std::vector<int64_t> &strides, void *data = nullptr);
aclnnTensorDescriptor(infiniopTensorDescriptor_t y_desc, void *data = nullptr); aclnnTensorDescriptor(infiniopTensorDescriptor_t y_desc, void *data = nullptr);
~aclnnTensorDescriptor(); ~aclnnTensorDescriptor();
size_t size() const; size_t numel() const;
std::string toString(); std::string toString();
}; };
......
#include "rms_norm_aclnn.h" #include "rms_norm_aclnn.h"
#include "../../../devices/ascend/common_ascend.h" #include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_cast.h>
#include <aclnnop/aclnn_rms_norm.h> #include <aclnnop/aclnn_rms_norm.h>
namespace op::rms_norm::ascend { namespace op::rms_norm::ascend {
// Device-side state owned by a RMSNorm Descriptor. After the upstream
// aclnnRmsNorm update, the weight tensor no longer needs a dtype cast, so
// the former castExecutor / cast descriptor / castWorkspaceSize members
// are gone and only the RMSNorm executor and its I/O descriptors remain.
struct Descriptor::Opaque {
    // Reusable aclnn executor (made repeatable in create()); aclnn launch
    // APIs take a non-const executor, hence `mutable`.
    mutable aclOpExecutor *executor;
    // Tensor descriptors owned by this Opaque; released in the destructor.
    aclnnTensorDescriptor_t y;
    aclnnTensorDescriptor_t x;
    aclnnTensorDescriptor_t w;
    aclnnTensorDescriptor_t rstd;
    // Workspace bytes reported by aclnnRmsNormGetWorkspaceSize for the
    // RMSNorm launch itself (the rstd scratch buffer is accounted for
    // separately in the descriptor's total workspace size).
    size_t workspaceSize;

    ~Opaque() {
        delete y;
        delete x;
        delete w;
        delete rstd;
        aclDestroyAclOpExecutor(executor);
    }
};
...@@ -42,17 +36,15 @@ infiniStatus_t Descriptor::create( ...@@ -42,17 +36,15 @@ infiniStatus_t Descriptor::create(
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle); auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
size_t workspace_size, cast_workspace_size = 0; size_t workspace_size = 0;
aclOpExecutor *executor = nullptr; aclOpExecutor *executor = nullptr;
aclOpExecutor *castExecutor = nullptr;
aclnnTensorDescriptor_t y = nullptr; aclnnTensorDescriptor_t y = nullptr;
aclnnTensorDescriptor_t x = nullptr; aclnnTensorDescriptor_t x = nullptr;
aclnnTensorDescriptor_t w = nullptr; aclnnTensorDescriptor_t w = nullptr;
aclnnTensorDescriptor_t rstd = nullptr; aclnnTensorDescriptor_t rstd = nullptr;
aclnnTensorDescriptor_t cast = nullptr;
std::vector<int64_t> slice_shape = {1, static_cast<int64_t>((info.shape)[1])}; std::vector<int64_t> slice_shape = {static_cast<int64_t>((info.shape)[1])};
auto slice_stride = std::vector<int64_t>(2, 1); auto slice_stride = std::vector<int64_t>(1, 1);
y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride); x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
w = new aclnnTensorDescriptor(w_desc); w = new aclnnTensorDescriptor(w_desc);
...@@ -64,28 +56,17 @@ infiniStatus_t Descriptor::create( ...@@ -64,28 +56,17 @@ infiniStatus_t Descriptor::create(
// Set rstdDesc // Set rstdDesc
// See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md // See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md
// rstdTensor cannot set nullptr in aclnn // rstdTensor cannot set nullptr in aclnn
auto rstd_shape = std::vector<int64_t>(2, 1); auto rstd_shape = std::vector<int64_t>(1, 1);
auto rstd_strides = std::vector<int64_t>(2, 1); auto rstd_strides = std::vector<int64_t>(1, 1);
rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides); rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides);
aclTensor *trstd = rstd->tensor; aclTensor *trstd = rstd->tensor;
if (w->dataType != x->dataType) {
cast = new aclnnTensorDescriptor(x->dataType, w->shape, w->strides);
}
// Get WorkspaceSize and set executor // Get WorkspaceSize and set executor
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, cast == nullptr ? tw : cast->tensor, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor); aclSetAclOpExecutorRepeatable(executor);
if (cast) {
aclTensor *tc = cast->tensor;
CHECK_ACL(aclnnCastGetWorkspaceSize(tw, cast->dataType, tc, &cast_workspace_size, &castExecutor));
aclSetAclOpExecutorRepeatable(castExecutor);
}
size_t allWorkspaceSize = workspace_size + cast_workspace_size + rstd->size() * aclDataTypeSize(rstd->dataType); size_t allWorkspaceSize = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
allWorkspaceSize = allWorkspaceSize + (cast == nullptr ? 0 : cast->size() * aclDataTypeSize(cast->dataType)); *desc_ptr = new Descriptor(new Opaque{executor, y, x, w, rstd, workspace_size}, info, allWorkspaceSize, handle_ascend->device, handle_ascend->device_id);
*desc_ptr = new Descriptor(new Opaque{executor, castExecutor, y, x, w, rstd, cast, workspace_size, cast_workspace_size}, info, allWorkspaceSize, handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -101,24 +82,13 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi ...@@ -101,24 +82,13 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi
auto trstd = _opaque->rstd->tensor; auto trstd = _opaque->rstd->tensor;
void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize); void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
void *castPtr = nullptr;
if (_opaque->cast) {
auto tcast = _opaque->cast->tensor;
castPtr = (void *)((float *)rstdPtr + _opaque->rstd->size());
AclSetTensorAddr(_opaque->castExecutor, 0, tw, (void *)w);
AclSetTensorAddr(_opaque->castExecutor, 1, tcast, castPtr);
CHECK_ACL(aclnnCast(nullptr, _opaque->castWorkspaceSize, _opaque->castExecutor, stream));
}
auto unit = infiniSizeOf(_info.atype); auto unit = infiniSizeOf(_info.atype);
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) { for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit); AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
if (_opaque->cast) {
AclSetTensorAddr(_opaque->executor, 1, _opaque->cast->tensor, castPtr);
} else {
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
}
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit); AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream)); CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment