Commit 74ffbff5 authored by YdrMaster's avatar YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 RmsNormInfo 构造流程,并与 Gemm 保持风格一致


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent 006fb46e
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu { namespace op::rms_norm::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create( ...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info; auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c ...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *y, const void *x, const void *w,
void *stream) { void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) { if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w)); CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
......
#ifndef __RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::rms_norm {
class RMSNormInfo {
RMSNormInfo() = default;
public:
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
static utils::Result<RMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<RMSNormInfo>(RMSNormInfo{
wtype,
atype,
epsilon,
y_desc->shape(),
y_desc->strides(),
x_desc->strides(),
});
}
};
} // namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
#ifndef RMS_NORM_H #ifndef RMS_NORM_H
#define RMS_NORM_H #define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
/// Plain record of the parameters of one RMS-norm problem.
struct RMSNormInfo {
    infiniDtype_t wtype;              // dtype of the weight tensor
    infiniDtype_t atype;              // dtype of the input/output tensors
    float epsilon;                    // epsilon value supplied by the caller
    std::vector<size_t> shape;        // shape of the output tensor
    std::vector<ptrdiff_t> y_strides; // strides of the output tensor
    std::vector<ptrdiff_t> x_strides; // strides of the input tensor

    /// Number of dimensions recorded in shape.
    /// Const-qualified so it can be called through a `const RMSNormInfo *`.
    size_t ndim() const { return shape.size(); }

    /// Extent of the last dimension. Precondition: shape is non-empty.
    size_t dim() const { return shape[ndim() - 1]; }
};
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->wtype = wtype;
info->atype = atype;
info->epsilon = epsilon;
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) { #include "../../operator.h"
return INFINI_STATUS_BAD_TENSOR_SHAPE; #include "info.h"
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
info->shape = std::move(y_desc->shape());
info->y_strides = std::move(y_desc->strides());
info->x_strides = std::move(x_desc->strides());
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\
namespace op::rms_norm::NAMESPACE { \ namespace op::rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
...@@ -79,14 +18,17 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip ...@@ -79,14 +18,17 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
RMSNormInfo info, \ RMSNormInfo info, \
size_t workspace_size, \ size_t workspace_size, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \ _opaque(opaque), \
_info(info), \ _info(info), \
_workspace_size(workspace_size) {} \ _workspace_size(workspace_size) {} \
\ \
public: \ public: \
~Descriptor(); \ ~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \ size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
...@@ -94,8 +36,13 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip ...@@ -94,8 +36,13 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
infiniopTensorDescriptor_t x_desc, \ infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \ infiniopTensorDescriptor_t w_desc, \
float epsilon); \ float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \ \
void *y, const void *x, const void *w, void *stream); \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
const void *w, \
void *stream) const; \
}; \ }; \
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment