Commit 8383c0e5 authored by zhangyue's avatar zhangyue
Browse files

issue/111: modify create RMSNormInfo

parent d854dbee
...@@ -22,8 +22,10 @@ infiniStatus_t Descriptor::create( ...@@ -22,8 +22,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info; auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) { if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
...@@ -59,7 +61,7 @@ infiniStatus_t launchKernel( ...@@ -59,7 +61,7 @@ infiniStatus_t launchKernel(
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) { void *y, const void *x, const void *w, void *stream) const {
if (workspace_size < _workspace_size) { if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment