Commit b4302732 authored by zhangyue
Browse files

issue/174: Separate getworkspace for rmsnorm

parent 02922ce9
......@@ -10,12 +10,15 @@ struct Descriptor::Opaque {
aclnnTensorDescriptor_t w;
aclnnTensorDescriptor_t rstd;
size_t workspaceSize;
aclOpExecutor *executor;
// Destructor: releases everything this Opaque holds.
// Deletes the four tensor descriptors (y, x, w, rstd) and then destroys the
// aclOpExecutor stored in the `executor` member.
~Opaque() {
delete y;
delete x;
delete w;
delete rstd;
// NOTE(review): assumes `executor` is a valid executor (or that the ACL API
// tolerates nullptr) by the time the destructor runs — confirm.
aclDestroyAclOpExecutor(executor);
}
};
......@@ -62,17 +65,16 @@ infiniStatus_t Descriptor::create(
// Get WorkspaceSize and set executor
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor);
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor(
new Opaque{y, x, w, rstd, workspace_size},
new Opaque{y, x, w, rstd, workspace_size, executor},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
aclDestroyAclOpExecutor(executor);
return INFINI_STATUS_SUCCESS;
}
......@@ -88,21 +90,16 @@ infiniStatus_t Descriptor::calculate(
auto tx = _opaque->x->tensor;
auto ty = _opaque->y->tensor;
auto trstd = _opaque->rstd->tensor;
size_t workspace_size_ = 0;
aclOpExecutor *executor = nullptr;
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(_info.epsilon), ty, trstd, &workspace_size_, &executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
auto unit = infiniSizeOf(_info.atype);
AclSetTensorAddr(executor, 1, tw, (void *)w);
AclSetTensorAddr(executor, 3, trstd, rstdPtr);
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, executor, stream));
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
}
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment