Commit b4302732 authored by zhangyue
Browse files

issue/174: Separate getworkspace for rmsnorm

parent 02922ce9
...@@ -10,12 +10,15 @@ struct Descriptor::Opaque { ...@@ -10,12 +10,15 @@ struct Descriptor::Opaque {
aclnnTensorDescriptor_t w; aclnnTensorDescriptor_t w;
aclnnTensorDescriptor_t rstd; aclnnTensorDescriptor_t rstd;
size_t workspaceSize; size_t workspaceSize;
aclOpExecutor *executor;
~Opaque() { ~Opaque() {
delete y; delete y;
delete x; delete x;
delete w; delete w;
delete rstd; delete rstd;
aclDestroyAclOpExecutor(executor);
} }
}; };
...@@ -62,17 +65,16 @@ infiniStatus_t Descriptor::create( ...@@ -62,17 +65,16 @@ infiniStatus_t Descriptor::create(
// Get WorkspaceSize and set executor // Get WorkspaceSize and set executor
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor)); CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor);
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle); auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType); size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{y, x, w, rstd, workspace_size}, new Opaque{y, x, w, rstd, workspace_size, executor},
std::move(info), std::move(info),
all_workspace_size, all_workspace_size,
handle_ascend->device, handle_ascend->device_id); handle_ascend->device, handle_ascend->device_id);
aclDestroyAclOpExecutor(executor);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -88,21 +90,16 @@ infiniStatus_t Descriptor::calculate( ...@@ -88,21 +90,16 @@ infiniStatus_t Descriptor::calculate(
auto tx = _opaque->x->tensor; auto tx = _opaque->x->tensor;
auto ty = _opaque->y->tensor; auto ty = _opaque->y->tensor;
auto trstd = _opaque->rstd->tensor; auto trstd = _opaque->rstd->tensor;
size_t workspace_size_ = 0;
aclOpExecutor *executor = nullptr;
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(_info.epsilon), ty, trstd, &workspace_size_, &executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize); void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
auto unit = infiniSizeOf(_info.atype); auto unit = infiniSizeOf(_info.atype);
AclSetTensorAddr(executor, 1, tw, (void *)w); AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
AclSetTensorAddr(executor, 3, trstd, rstdPtr); AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) { for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit); AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit); AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, executor, stream)); CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment