Unverified Commit e5bda616 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #122 from YdrMaster/main

issue/121/feat: 添加 Result 类型
parents beaf1e8c fd5d90c9
......@@ -2,57 +2,10 @@
#define CAUSAL_SOFTMAX_H
#include "../../operator.h"
#include "../../tensor.h"
#include <iostream>
#include <vector>
struct CausalSoftmaxInfo {
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
};
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->dtype = dtype;
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
info->batch_size = batch_size;
info->stride_b = stride_b;
info->seq_len = seq_len;
info->stride_i = stride_i;
info->total_seq_len = total_seq_len;
info->stride_j = stride_j;
return INFINI_STATUS_SUCCESS;
}
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
......@@ -65,20 +18,26 @@ inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopT
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *data, void *stream); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *data, \
void *stream) const; \
}; \
}
......
......@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info;
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc));
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id);
auto result = CausalSoftmaxInfo::create(y_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *data,
void *stream) {
void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
} else if (_info.dtype == INFINI_DTYPE_F32) {
......
#ifndef __CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::causal_softmax {
class CausalSoftmaxInfo {
CausalSoftmaxInfo() = default;
public:
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
dtype,
batch_size,
stride_b,
seq_len,
stride_i,
total_seq_len,
stride_j});
}
};
} // namespace op::causal_softmax
#endif // __CAUSAL_SOFTMAX_INFO_H__
......@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
auto info = result.take();
auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()),
{static_cast<int64_t>(info.c_matrix.rows), static_cast<int64_t>(info.c_matrix.cols)},
{static_cast<int64_t>(info.m), static_cast<int64_t>(info.n)},
{info.c_matrix.row_stride, info.c_matrix.col_stride});
auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()),
{static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)},
......
......@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
auto info = result.take();
cnnlTensorDescriptor_t a, b, c;
CHECK_BANG(cnnlCreateTensorDescriptor(&a));
......
#ifndef __BLAS_H__
#define __BLAS_H__
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
struct BlasMatrix {
size_t ndim;
size_t batch;
ptrdiff_t stride;
size_t rows;
size_t cols;
ptrdiff_t row_stride;
ptrdiff_t col_stride;
BlasMatrix() = default;
BlasMatrix(infiniopTensorDescriptor_t layout, infiniStatus_t *status) {
if (layout->ndim() == 2) {
ndim = 2;
batch = 1;
stride = 0;
rows = layout->dim(0);
cols = layout->dim(1);
row_stride = layout->stride(0);
col_stride = layout->stride(1);
} else if (layout->ndim() == 3) {
ndim = 3;
batch = layout->dim(0);
stride = batch == 1 ? 0 : layout->stride(0);
rows = layout->dim(1);
cols = layout->dim(2);
row_stride = layout->stride(1);
col_stride = layout->stride(2);
} else {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
}
if (row_stride != 1 && col_stride != 1) {
*status = INFINI_STATUS_BAD_TENSOR_STRIDES;
return;
}
*status = INFINI_STATUS_SUCCESS;
}
bool match_batch(size_t _batch) const {
return batch == _batch || batch == 1;
}
void transpose() {
std::swap(rows, cols);
std::swap(row_stride, col_stride);
}
ptrdiff_t ld() const {
return row_stride == 1 ? col_stride : row_stride;
}
};
enum class MatrixLayout : char {
COL_MAJOR,
ROW_MAJOR,
};
struct MatmulInfo {
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix c_matrix;
size_t m, n, k, batch;
bool is_transed = false;
MatmulInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniStatus_t *status,
MatrixLayout layout) {
a_matrix = BlasMatrix(a_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
b_matrix = BlasMatrix(b_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
c_matrix = BlasMatrix(c_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
if (c_matrix.rows != a_matrix.rows || c_matrix.cols != b_matrix.cols || a_matrix.cols != b_matrix.rows) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
}
batch = c_matrix.batch;
if (!a_matrix.match_batch(batch) || !b_matrix.match_batch(batch)) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
}
if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) {
c_matrix.transpose();
b_matrix.transpose();
a_matrix.transpose();
std::swap(a_matrix, b_matrix);
is_transed = true;
}
m = c_matrix.rows;
n = c_matrix.cols;
k = a_matrix.cols;
}
};
} // namespace op::gemm
#endif // __BLAS_H__
......@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
nullptr,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -2,7 +2,7 @@
#define __GEMM_H__
#include "../../operator.h"
#include "blas.h"
#include "info.h"
/**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
......@@ -52,6 +52,7 @@
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
......@@ -64,13 +65,13 @@
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
workspace_size(workspace_size_) {} \
_workspace_size(workspace_size_) {} \
\
public: \
size_t workspace_size; \
\
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
......@@ -79,8 +80,7 @@
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *workspace, size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
......
#ifndef __GEMM_INFO_H__
#define __GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
class BlasMatrix {
BlasMatrix() = default;
public:
size_t ndim;
size_t batch;
ptrdiff_t stride;
size_t rows;
size_t cols;
ptrdiff_t row_stride;
ptrdiff_t col_stride;
static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
BlasMatrix ans;
if (layout->ndim() == 2) {
ans.ndim = 2;
ans.batch = 1;
ans.stride = 0;
ans.rows = layout->dim(0);
ans.cols = layout->dim(1);
ans.row_stride = layout->stride(0);
ans.col_stride = layout->stride(1);
} else if (layout->ndim() == 3) {
ans.ndim = 3;
ans.batch = layout->dim(0);
ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
ans.rows = layout->dim(1);
ans.cols = layout->dim(2);
ans.row_stride = layout->stride(1);
ans.col_stride = layout->stride(2);
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (ans.row_stride != 1 && ans.col_stride != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<BlasMatrix>(ans);
}
bool match_batch(size_t _batch) const {
return batch == _batch || batch == 1;
}
void transpose() {
std::swap(rows, cols);
std::swap(row_stride, col_stride);
}
ptrdiff_t ld() const {
return row_stride == 1 ? col_stride : row_stride;
}
};
enum class MatrixLayout : char {
COL_MAJOR,
ROW_MAJOR,
};
class MatmulInfo {
MatmulInfo() = default;
public:
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix c_matrix;
size_t m, n, k, batch;
bool is_transed;
static utils::Result<MatmulInfo> create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
MatrixLayout layout) {
auto a_matrix = BlasMatrix::create(a_desc);
CHECK_RESULT(a_matrix);
auto b_matrix = BlasMatrix::create(b_desc);
CHECK_RESULT(b_matrix);
auto c_matrix = BlasMatrix::create(c_desc);
CHECK_RESULT(c_matrix);
if (c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto batch = c_matrix->batch;
if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto is_transed = false;
if ((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) {
c_matrix->transpose();
b_matrix->transpose();
a_matrix->transpose();
std::swap(a_matrix, b_matrix);
is_transed = true;
}
auto m = c_matrix->rows;
auto n = c_matrix->cols;
auto k = a_matrix->cols;
return utils::Result<MatmulInfo>(MatmulInfo{
a_matrix.take(),
b_matrix.take(),
c_matrix.take(),
m,
n,
k,
batch,
is_transed});
}
};
} // namespace op::gemm
#endif // __GEMM_INFO_H__
......@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -72,7 +72,7 @@ infiniopGetGemmWorkspaceSize(
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
......
......@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create(
auto dst_strides = y_desc->strides().data();
auto src_strides = x_desc->strides().data();
auto element_size = infiniSizeOf(dtype);
auto meta = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
if (!meta) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
auto result = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
std::move(*meta),
result.take(),
nullptr,
handle->device,
handle->device_id);
......
......@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
......@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id);
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) {
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
......
#ifndef __RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::rms_norm {
class RMSNormInfo {
RMSNormInfo() = default;
public:
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
static utils::Result<RMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<RMSNormInfo>(RMSNormInfo{
wtype,
atype,
epsilon,
y_desc->shape(),
y_desc->strides(),
x_desc->strides(),
});
}
};
} // namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
#ifndef RMS_NORM_H
#define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
struct RMSNormInfo {
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() { return shape.size(); }
size_t dim() { return shape[ndim() - 1]; }
};
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->wtype = wtype;
info->atype = atype;
info->epsilon = epsilon;
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
info->shape = std::move(y_desc->shape());
info->y_strides = std::move(y_desc->strides());
info->x_strides = std::move(x_desc->strides());
return INFINI_STATUS_SUCCESS;
}
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
......@@ -79,14 +18,17 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
RMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
......@@ -94,8 +36,13 @@ inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescrip
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
const void *w, \
void *stream) const; \
}; \
}
......
#ifndef INFINIUTILS_H
#define INFINIUTILS_H
#include "infinicore.h"
#include "utils/check.h"
#include "utils/custom_types.h"
#include "utils/rearrange.h"
......
......@@ -13,7 +13,7 @@ namespace utils {
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
: _meta(std::move(meta)) {}
std::optional<RearrangeMeta> RearrangeMeta::create(
Result<RearrangeMeta> RearrangeMeta::create(
const size_t *shape,
const ptrdiff_t *dst_strides_,
const ptrdiff_t *src_strides_,
......@@ -32,7 +32,9 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
// 剔除初始的 1 长维度
if (shape[i] != 1) {
auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
// assert (sd != 0)
if (sd == 0) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
dims.push_back(Dim{shape[i], sd, ss});
}
}
......@@ -81,7 +83,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
for (ptrdiff_t i = ndim; i > 0; --i) {
meta[1 + i - 1] *= meta[1 + i];
}
return {RearrangeMeta(std::move(meta))};
return Result<RearrangeMeta>(meta);
}
size_t RearrangeMeta::ndim() const { return (_meta.size() - 2) / 3; }
......
#ifndef __INFINIUTILS_REARRANGE_H__
#define __INFINIUTILS_REARRANGE_H__
#include <optional>
#include <stddef.h>
#include "result.hpp"
#include <cstddef>
#include <vector>
namespace utils {
......@@ -12,7 +12,7 @@ class RearrangeMeta {
RearrangeMeta(std::vector<ptrdiff_t>);
public:
static std::optional<RearrangeMeta> create(
static Result<RearrangeMeta> create(
const size_t *shape,
const ptrdiff_t *dst_strides,
const ptrdiff_t *src_strides,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment