Commit 74ffbff5 authored by YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 RMSNormInfo 构造流程，并与 Gemm 保持风格一致


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent 006fb46e
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu { namespace op::rms_norm::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create( ...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info; auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c ...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *y, const void *x, const void *w, void *workspace, size_t workspace_size,
void *stream) { void *y, const void *x, const void *w,
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) { if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w)); CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
......
#ifndef __RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::rms_norm {
/// Metadata describing one RMS-norm problem, captured from the y / x / w
/// tensor descriptors once they have passed the checks in `create()`.
class RMSNormInfo {
    // Declared private; `create()` is the construction path provided by this class.
    // NOTE(review): `create()` brace-initializes the public fields below. Under
    // C++17 that is aggregate initialization and is accepted even though this
    // defaulted constructor is declared; C++20 stops treating a class with a
    // user-declared constructor as an aggregate -- confirm the project's
    // language standard before raising it.
    RMSNormInfo() = default;

public:
    // Field order matters: `create()` fills these positionally.
    infiniDtype_t wtype;              // dtype reported by w_desc
    infiniDtype_t atype;              // dtype reported by y_desc (x_desc must match it)
    float epsilon;                    // value handed to create(), stored unchanged
    std::vector<size_t> shape;        // copy of y_desc->shape()
    std::vector<ptrdiff_t> y_strides; // copy of y_desc->strides()
    std::vector<ptrdiff_t> x_strides; // copy of x_desc->strides()

    /// Number of dimensions recorded in `shape`.
    size_t ndim() const { return shape.size(); }

    /// Extent of the last dimension in `shape`; `shape` must not be empty.
    size_t dim() const { return shape.back(); }

    /// Validates the three descriptors and, on success, packs them into an
    /// `RMSNormInfo`.
    ///
    /// Checks are applied in this order, and the first failure is returned:
    ///   1. dtype  -> INFINI_STATUS_BAD_TENSOR_DTYPE
    ///      x must share y's dtype; F16 activations accept F16 or F32
    ///      weights; F32 / F64 activations require the same weight dtype;
    ///      any other activation dtype is rejected.
    ///   2. rank   -> INFINI_STATUS_BAD_TENSOR_SHAPE
    ///      y and x must be 2-D, w must be 1-D.
    ///   3. extent -> INFINI_STATUS_BAD_TENSOR_SHAPE
    ///      x must match y in both dimensions; w's length must equal y's
    ///      second dimension.
    ///   4. stride -> INFINI_STATUS_BAD_TENSOR_STRIDES
    ///      w's stride, and the stride of the second dimension of x and y,
    ///      must all be 1.
    ///
    /// @param y_desc  descriptor of the output tensor
    /// @param x_desc  descriptor of the input tensor
    /// @param w_desc  descriptor of the weight tensor
    /// @param epsilon stored as-is in the resulting info
    /// @return a Result holding the populated info, or one of the status
    ///         codes listed above.
    static utils::Result<RMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t w_desc,
        float epsilon) {

        const auto act_dtype = y_desc->dtype();
        const auto weight_dtype = w_desc->dtype();

        // --- 1. dtype compatibility -------------------------------------
        if (x_desc->dtype() != act_dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        const bool act_is_f16 = act_dtype == INFINI_DTYPE_F16;
        const bool act_is_f32_or_f64 = act_dtype == INFINI_DTYPE_F32 || act_dtype == INFINI_DTYPE_F64;
        const bool weight_dtype_ok
            = act_is_f16
                ? (weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_F32)
                : (act_is_f32_or_f64 && weight_dtype == act_dtype);
        if (!weight_dtype_ok) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // --- 2. ranks ---------------------------------------------------
        const bool ranks_ok = y_desc->ndim() == 2 && x_desc->ndim() == 2 && w_desc->ndim() == 1;
        if (!ranks_ok) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- 3. extents -------------------------------------------------
        const size_t rows = y_desc->shape()[0];
        const size_t cols = y_desc->shape()[1];
        const bool extents_ok = x_desc->shape()[0] == rows
                             && x_desc->shape()[1] == cols
                             && w_desc->shape()[0] == cols;
        if (!extents_ok) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- 4. unit strides --------------------------------------------
        // Both original stride guards report the same status, so they are
        // evaluated together, w first, then x, then y.
        const bool unit_strides = w_desc->stride(0) == 1
                               && x_desc->stride(1) == 1
                               && y_desc->stride(1) == 1;
        if (!unit_strides) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        // Positional initialization: order must match the field declarations.
        return utils::Result<RMSNormInfo>(RMSNormInfo{
            weight_dtype,
            act_dtype,
            epsilon,
            y_desc->shape(),
            y_desc->strides(),
            x_desc->strides(),
        });
    }
};
} // namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
#ifndef RMS_NORM_H #ifndef RMS_NORM_H
#define RMS_NORM_H #define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
struct RMSNormInfo {
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() { return shape.size(); }
size_t dim() { return shape[ndim() - 1]; }
};
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->wtype = wtype;
info->atype = atype;
info->epsilon = epsilon;
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0]; #include "../../operator.h"
size_t dim = y_desc->shape()[1]; #include "info.h"
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; #define DESCRIPTOR(NAMESPACE) \
} \
namespace op::rms_norm::NAMESPACE { \
if (w_desc->stride(0) != 1) { class Descriptor final : public InfiniopDescriptor { \
return INFINI_STATUS_BAD_TENSOR_STRIDES; struct Opaque; \
} Opaque *_opaque; \
RMSNormInfo _info; \
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) { size_t _workspace_size; \
return INFINI_STATUS_BAD_TENSOR_STRIDES; \
} Descriptor( \
Opaque *opaque, \
info->shape = std::move(y_desc->shape()); RMSNormInfo info, \
info->y_strides = std::move(y_desc->strides()); size_t workspace_size, \
info->x_strides = std::move(x_desc->strides()); infiniDevice_t device_type, \
int device_id) \
return INFINI_STATUS_SUCCESS; : InfiniopDescriptor{device_type, device_id}, \
} _opaque(opaque), \
_info(info), \
#define DESCRIPTOR(NAMESPACE) \ _workspace_size(workspace_size) {} \
namespace op::rms_norm::NAMESPACE { \ \
class Descriptor final : public InfiniopDescriptor { \ public: \
struct Opaque; \ ~Descriptor(); \
Opaque *_opaque; \ \
RMSNormInfo _info; \ size_t workspaceSize() const { return _workspace_size; } \
size_t _workspace_size; \ \
\ static infiniStatus_t create( \
Descriptor( \ infiniopHandle_t handle, \
Opaque *opaque, \ Descriptor **desc_ptr, \
RMSNormInfo info, \ infiniopTensorDescriptor_t y_desc, \
size_t workspace_size, \ infiniopTensorDescriptor_t x_desc, \
infiniDevice_t device_type, \ infiniopTensorDescriptor_t w_desc, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \ float epsilon); \
_opaque(opaque), \ \
_info(info), \ infiniStatus_t calculate( \
_workspace_size(workspace_size) {} \ void *workspace, size_t workspace_size, \
\ void *y, \
public: \ const void *x, \
~Descriptor(); \ const void *w, \
size_t workspaceSize() const { return _workspace_size; } \ void *stream) const; \
static infiniStatus_t create( \ }; \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
}; \
} }
#endif // RMS_NORM_H #endif // RMS_NORM_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment