"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "8f2f7cd61cae57680baa3cef2c00c307f51b5146"
Unverified Commit e5bda616 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #122 from YdrMaster/main

issue/121/feat: 添加 Result 类型
parents beaf1e8c fd5d90c9
...@@ -2,84 +2,43 @@ ...@@ -2,84 +2,43 @@
#define CAUSAL_SOFTMAX_H #define CAUSAL_SOFTMAX_H
#include "../../operator.h" #include "../../operator.h"
#include "../../tensor.h" #include "info.h"
#include <iostream>
#include <vector> #define DESCRIPTOR(NAMESPACE) \
\
struct CausalSoftmaxInfo { namespace op::causal_softmax::NAMESPACE { \
infiniDtype_t dtype; class Descriptor final : public InfiniopDescriptor { \
size_t batch_size; struct Opaque; \
ptrdiff_t stride_b; Opaque *_opaque; \
size_t seq_len; CausalSoftmaxInfo _info; \
ptrdiff_t stride_i; size_t _workspace_size; \
size_t total_seq_len; \
ptrdiff_t stride_j; Descriptor( \
}; Opaque *opaque, \
CausalSoftmaxInfo info, \
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) { size_t workspace_size, \
auto dtype = y_desc->dtype(); infiniDevice_t device_type, \
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) { int device_id) \
return INFINI_STATUS_BAD_TENSOR_DTYPE; : InfiniopDescriptor{device_type, device_id}, \
} _opaque(opaque), \
info->dtype = dtype; _info(info), \
_workspace_size(workspace_size) {} \
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) { \
return INFINI_STATUS_BAD_TENSOR_SHAPE; public: \
} ~Descriptor(); \
\
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) { size_t workspaceSize() const { return _workspace_size; } \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
} static infiniStatus_t create( \
infiniopHandle_t handle, \
size_t batch_size = 1; Descriptor **desc_ptr, \
ptrdiff_t stride_b = 0; infiniopTensorDescriptor_t y_desc); \
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2]; \
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2]; infiniStatus_t calculate( \
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1]; void *workspace, size_t workspace_size, \
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1]; void *data, \
if (y_desc->ndim() == 3) { void *stream) const; \
stride_b = y_desc->strides()[0]; }; \
batch_size = y_desc->shape()[0];
}
info->batch_size = batch_size;
info->stride_b = stride_b;
info->seq_len = seq_len;
info->stride_i = stride_i;
info->total_seq_len = total_seq_len;
info->stride_j = stride_j;
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
CausalSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *data, void *stream); \
}; \
} }
#endif // CAUSAL_SOFTMAX_H #endif // CAUSAL_SOFTMAX_H
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu { namespace op::causal_softmax::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info; auto result = CausalSoftmaxInfo::create(y_desc);
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) { ...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *data, void *workspace, size_t workspace_size,
void *stream) { void *data,
void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) { if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data)); CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
} else if (_info.dtype == INFINI_DTYPE_F32) { } else if (_info.dtype == INFINI_DTYPE_F32) {
......
#ifndef __CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::causal_softmax {
class CausalSoftmaxInfo {
    // Instances are only produced by the validating factory `create`.
    CausalSoftmaxInfo() = default;

public:
    infiniDtype_t dtype;
    size_t batch_size;
    ptrdiff_t stride_b;
    size_t seq_len;
    ptrdiff_t stride_i;
    size_t total_seq_len;
    ptrdiff_t stride_j;

    /// Validate `y_desc` and build a CausalSoftmaxInfo from it.
    ///
    /// Checks performed:
    /// - dtype is F16 or F32, otherwise INFINI_STATUS_BAD_TENSOR_DTYPE;
    /// - tensor is 2-D ([seq_len, total_seq_len]) or 3-D
    ///   ([batch, seq_len, total_seq_len]), otherwise
    ///   INFINI_STATUS_BAD_TENSOR_SHAPE;
    /// - total_seq_len >= seq_len (the causal mask must fit), otherwise
    ///   INFINI_STATUS_BAD_TENSOR_SHAPE.
    static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
        // Hoist descriptor queries once instead of re-querying per check.
        auto dtype = y_desc->dtype();
        if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        auto ndim = y_desc->ndim();
        if (ndim != 2 && ndim != 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        const auto &shape = y_desc->shape();
        const auto &strides = y_desc->strides();
        if (shape[ndim - 1] < shape[ndim - 2]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // 2-D tensors are treated as a single batch with stride 0.
        size_t batch_size = 1;
        ptrdiff_t stride_b = 0;
        size_t seq_len = shape[ndim - 2];
        ptrdiff_t stride_i = strides[ndim - 2];
        size_t total_seq_len = shape[ndim - 1];
        ptrdiff_t stride_j = strides[ndim - 1];
        if (ndim == 3) {
            stride_b = strides[0];
            batch_size = shape[0];
        }
        return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
            dtype,
            batch_size,
            stride_b,
            seq_len,
            stride_i,
            total_seq_len,
            stride_j});
    }
};
} // namespace op::causal_softmax
#endif // __CAUSAL_SOFTMAX_INFO_H__
...@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create( ...@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) { auto info = result.take();
return status;
}
auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()), auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()),
{static_cast<int64_t>(info.c_matrix.rows), static_cast<int64_t>(info.c_matrix.cols)}, {static_cast<int64_t>(info.m), static_cast<int64_t>(info.n)},
{info.c_matrix.row_stride, info.c_matrix.col_stride}); {info.c_matrix.row_stride, info.c_matrix.col_stride});
auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()), auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()),
{static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)}, {static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)},
......
...@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create( ...@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) { auto info = result.take();
return status;
}
cnnlTensorDescriptor_t a, b, c; cnnlTensorDescriptor_t a, b, c;
CHECK_BANG(cnnlCreateTensorDescriptor(&a)); CHECK_BANG(cnnlCreateTensorDescriptor(&a));
......
#ifndef __BLAS_H__
#define __BLAS_H__
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
/// Non-owning matrix view over a 2-D or 3-D (batched) tensor descriptor.
struct BlasMatrix {
    // Zero defaults keep the members in a defined state even when the
    // validating constructor below bails out on an error path (previously
    // they were left indeterminate).
    size_t ndim = 0;
    size_t batch = 1;
    ptrdiff_t stride = 0;
    size_t rows = 0;
    size_t cols = 0;
    ptrdiff_t row_stride = 0;
    ptrdiff_t col_stride = 0;

    BlasMatrix() = default;

    /// Interpret `layout` as a matrix: 2-D ([rows, cols]) or 3-D
    /// ([batch, rows, cols]). Any other rank yields
    /// INFINI_STATUS_BAD_TENSOR_SHAPE. One of the two innermost strides must
    /// be 1, otherwise INFINI_STATUS_BAD_TENSOR_STRIDES. `*status` always
    /// receives the outcome; on failure the view keeps its zero defaults.
    BlasMatrix(infiniopTensorDescriptor_t layout, infiniStatus_t *status) {
        auto rank = layout->ndim();
        if (rank == 2) {
            ndim = 2;
            batch = 1;
            stride = 0;
            rows = layout->dim(0);
            cols = layout->dim(1);
            row_stride = layout->stride(0);
            col_stride = layout->stride(1);
        } else if (rank == 3) {
            ndim = 3;
            batch = layout->dim(0);
            // batch == 1 stores stride 0 — presumably so a single matrix can
            // be broadcast across the batch; confirm against backends.
            stride = batch == 1 ? 0 : layout->stride(0);
            rows = layout->dim(1);
            cols = layout->dim(2);
            row_stride = layout->stride(1);
            col_stride = layout->stride(2);
        } else {
            *status = INFINI_STATUS_BAD_TENSOR_SHAPE;
            return;
        }

        if (row_stride != 1 && col_stride != 1) {
            *status = INFINI_STATUS_BAD_TENSOR_STRIDES;
            return;
        }

        *status = INFINI_STATUS_SUCCESS;
    }

    /// True when this matrix can serve a `_batch`-sized problem:
    /// exact match, or a batch of 1.
    bool match_batch(size_t _batch) const {
        return batch == _batch || batch == 1;
    }

    /// Logical transpose of the view: swap rows/cols and their strides.
    /// No data is moved.
    void transpose() {
        std::swap(rows, cols);
        std::swap(row_stride, col_stride);
    }

    /// Leading dimension for BLAS calls: the stride of the non-unit axis.
    ptrdiff_t ld() const {
        return row_stride == 1 ? col_stride : row_stride;
    }
};
// Storage order requested for the output matrix; MatmulInfo uses it to
// decide whether the whole problem must be transposed.
enum class MatrixLayout : char {
    COL_MAJOR, // consecutive rows are adjacent in memory (Fortran/BLAS order)
    ROW_MAJOR, // consecutive columns are adjacent in memory (C order)
};
/// Fully-validated description of a (possibly batched) GEMM problem
/// C = A · B, canonicalized so that C's unit-stride axis matches the
/// requested layout.
struct MatmulInfo {
    BlasMatrix a_matrix;
    BlasMatrix b_matrix;
    BlasMatrix c_matrix;
    // Problem sizes: C is m×n, A is m×k, B is k×n; batch is taken from C.
    size_t m, n, k, batch;
    // True when the problem was rewritten as Cᵀ = Bᵀ·Aᵀ (see constructor).
    bool is_transed = false;

    /// Build and validate the problem description from the descriptors.
    /// `*status` receives the outcome; on failure the fields are not
    /// meaningful and the object must be discarded.
    MatmulInfo(infiniopTensorDescriptor_t c_desc,
               infiniopTensorDescriptor_t a_desc,
               infiniopTensorDescriptor_t b_desc,
               infiniStatus_t *status,
               MatrixLayout layout) {
        a_matrix = BlasMatrix(a_desc, status);
        if (*status != INFINI_STATUS_SUCCESS) {
            return;
        }
        b_matrix = BlasMatrix(b_desc, status);
        if (*status != INFINI_STATUS_SUCCESS) {
            return;
        }
        c_matrix = BlasMatrix(c_desc, status);
        if (*status != INFINI_STATUS_SUCCESS) {
            return;
        }
        // Dimension compatibility: (m×k)·(k×n) → (m×n).
        if (c_matrix.rows != a_matrix.rows || c_matrix.cols != b_matrix.cols || a_matrix.cols != b_matrix.rows) {
            *status = INFINI_STATUS_BAD_TENSOR_SHAPE;
            return;
        }
        // Batch count comes from C; A and B must match it or have batch 1.
        batch = c_matrix.batch;
        if (!a_matrix.match_batch(batch) || !b_matrix.match_batch(batch)) {
            *status = INFINI_STATUS_BAD_TENSOR_SHAPE;
            return;
        }
        // If C's unit-stride axis disagrees with the requested layout,
        // apply (A·B)ᵀ = Bᵀ·Aᵀ: transpose all three views and swap A/B.
        // Ordering matters — the transposes must precede the swap, and
        // m/n/k below are read from the already-canonicalized views.
        if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1)
            || (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) {
            c_matrix.transpose();
            b_matrix.transpose();
            a_matrix.transpose();
            std::swap(a_matrix, b_matrix);
            is_transed = true;
        }
        m = c_matrix.rows;
        n = c_matrix.cols;
        k = a_matrix.cols;
    }
};
} // namespace op::gemm
#endif // __BLAS_H__
...@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create( ...@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
nullptr, nullptr,
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create( ...@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define __GEMM_H__ #define __GEMM_H__
#include "../../operator.h" #include "../../operator.h"
#include "blas.h" #include "info.h"
/** /**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明 * # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
...@@ -44,50 +44,50 @@ ...@@ -44,50 +44,50 @@
* 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。 * 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。
*/ */
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\ \
namespace op::gemm::NAMESPACE { \ namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
Opaque *_opaque; \ Opaque *_opaque; \
infiniDtype_t _dtype; \ infiniDtype_t _dtype; \
MatmulInfo _info; \ MatmulInfo _info; \
\ size_t _workspace_size; \
Descriptor( \ \
infiniDtype_t dtype, \ Descriptor( \
MatmulInfo info, \ infiniDtype_t dtype, \
size_t workspace_size_, \ MatmulInfo info, \
Opaque *opaque, \ size_t workspace_size_, \
infiniDevice_t device_type, \ Opaque *opaque, \
int device_id) \ infiniDevice_t device_type, \
: InfiniopDescriptor{device_type, device_id}, \ int device_id) \
_opaque(opaque), \ : InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \ _opaque(opaque), \
_info(info), \ _dtype(dtype), \
workspace_size(workspace_size_) {} \ _info(info), \
\ _workspace_size(workspace_size_) {} \
public: \ \
size_t workspace_size; \ public: \
\ ~Descriptor(); \
~Descriptor(); \ \
\ size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \ \
infiniopHandle_t handle, \ static infiniStatus_t create( \
Descriptor **desc_ptr, \ infiniopHandle_t handle, \
infiniopTensorDescriptor_t c_desc, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t b_desc); \ infiniopTensorDescriptor_t a_desc, \
\ infiniopTensorDescriptor_t b_desc); \
infiniStatus_t calculate( \ \
void *workspace, \ infiniStatus_t calculate( \
size_t workspace_size, \ void *workspace, size_t workspace_size, \
void *c, \ void *c, \
float beta, \ float beta, \
const void *a, \ const void *a, \
const void *b, \ const void *b, \
float alpha, \ float alpha, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
#endif // __GEMM_H__ #endif // __GEMM_H__
#ifndef __GEMM_INFO_H__
#define __GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
/// Non-owning matrix view over a 2-D or 3-D (batched) tensor descriptor.
class BlasMatrix {
    // Instances are only produced by the validating factory `create`.
    BlasMatrix() = default;

public:
    // Zero defaults keep a freshly constructed view in a defined state
    // until `create` fills it in.
    size_t ndim = 0;
    size_t batch = 1;
    ptrdiff_t stride = 0;
    size_t rows = 0;
    size_t cols = 0;
    ptrdiff_t row_stride = 0;
    ptrdiff_t col_stride = 0;

    /// Interpret `layout` as a matrix: 2-D ([rows, cols]) or 3-D
    /// ([batch, rows, cols]). Any other rank yields
    /// INFINI_STATUS_BAD_TENSOR_SHAPE. One of the two innermost strides must
    /// be 1, otherwise INFINI_STATUS_BAD_TENSOR_STRIDES.
    static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
        // Query the rank once instead of per branch.
        auto rank = layout->ndim();
        BlasMatrix ans;
        if (rank == 2) {
            ans.ndim = 2;
            ans.batch = 1;
            ans.stride = 0;
            ans.rows = layout->dim(0);
            ans.cols = layout->dim(1);
            ans.row_stride = layout->stride(0);
            ans.col_stride = layout->stride(1);
        } else if (rank == 3) {
            ans.ndim = 3;
            ans.batch = layout->dim(0);
            // batch == 1 stores stride 0 — presumably so a single matrix can
            // be broadcast across the batch; confirm against backends.
            ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
            ans.rows = layout->dim(1);
            ans.cols = layout->dim(2);
            ans.row_stride = layout->stride(1);
            ans.col_stride = layout->stride(2);
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        if (ans.row_stride != 1 && ans.col_stride != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        return utils::Result<BlasMatrix>(ans);
    }

    /// True when this matrix can serve a `_batch`-sized problem:
    /// exact match, or a batch of 1.
    bool match_batch(size_t _batch) const {
        return batch == _batch || batch == 1;
    }

    /// Logical transpose of the view: swap rows/cols and their strides.
    /// No data is moved.
    void transpose() {
        std::swap(rows, cols);
        std::swap(row_stride, col_stride);
    }

    /// Leading dimension for BLAS calls: the stride of the non-unit axis.
    ptrdiff_t ld() const {
        return row_stride == 1 ? col_stride : row_stride;
    }
};
// Storage order requested for the output matrix; MatmulInfo::create uses it
// to decide whether the whole problem must be transposed.
enum class MatrixLayout : char {
    COL_MAJOR, // consecutive rows are adjacent in memory (Fortran/BLAS order)
    ROW_MAJOR, // consecutive columns are adjacent in memory (C order)
};
/// Fully-validated description of a (possibly batched) GEMM problem
/// C = A · B, canonicalized so that C's unit-stride axis matches the
/// requested layout. Only creatable through `create`.
class MatmulInfo {
    MatmulInfo() = default;

public:
    BlasMatrix a_matrix;
    BlasMatrix b_matrix;
    BlasMatrix c_matrix;
    // Problem sizes: C is m×n, A is m×k, B is k×n; batch is taken from C.
    size_t m, n, k, batch;
    // True when the problem was rewritten as Cᵀ = Bᵀ·Aᵀ (see `create`).
    bool is_transed;

    /// Validate the three descriptors and assemble the problem description.
    /// Propagates the first failing sub-result via CHECK_RESULT; shape or
    /// batch mismatches yield INFINI_STATUS_BAD_TENSOR_SHAPE.
    static utils::Result<MatmulInfo> create(
        infiniopTensorDescriptor_t c_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        MatrixLayout layout) {
        // NOTE: a_matrix/b_matrix/c_matrix are utils::Result wrappers here;
        // the BlasMatrix inside is reached through operator->.
        auto a_matrix = BlasMatrix::create(a_desc);
        CHECK_RESULT(a_matrix);
        auto b_matrix = BlasMatrix::create(b_desc);
        CHECK_RESULT(b_matrix);
        auto c_matrix = BlasMatrix::create(c_desc);
        CHECK_RESULT(c_matrix);
        // Dimension compatibility: (m×k)·(k×n) → (m×n).
        if (c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        // Batch count comes from C; A and B must match it or have batch 1.
        auto batch = c_matrix->batch;
        if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        auto is_transed = false;
        // If C's unit-stride axis disagrees with the requested layout,
        // apply (A·B)ᵀ = Bᵀ·Aᵀ: transpose all three views, then swap the
        // A/B *Result wrappers*. The transposes must precede the swap, and
        // m/n/k below are read from the already-canonicalized views.
        if ((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1)
            || (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) {
            c_matrix->transpose();
            b_matrix->transpose();
            a_matrix->transpose();
            std::swap(a_matrix, b_matrix);
            is_transed = true;
        }
        auto m = c_matrix->rows;
        auto n = c_matrix->cols;
        auto k = a_matrix->cols;
        // take() moves the validated matrices out of their Result wrappers.
        return utils::Result<MatmulInfo>(MatmulInfo{
            a_matrix.take(),
            b_matrix.take(),
            c_matrix.take(),
            m,
            n,
            k,
            batch,
            is_transed});
    }
};
} // namespace op::gemm
#endif // __GEMM_INFO_H__
...@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create( ...@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create( ...@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -70,9 +70,9 @@ infiniopGetGemmWorkspaceSize( ...@@ -70,9 +70,9 @@ infiniopGetGemmWorkspaceSize(
infiniopGemmDescriptor_t desc, infiniopGemmDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \ *size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
......
...@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create( ...@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create(
auto dst_strides = y_desc->strides().data(); auto dst_strides = y_desc->strides().data();
auto src_strides = x_desc->strides().data(); auto src_strides = x_desc->strides().data();
auto element_size = infiniSizeOf(dtype); auto element_size = infiniSizeOf(dtype);
auto meta = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
if (!meta) { auto result = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
return INFINI_STATUS_BAD_TENSOR_STRIDES; CHECK_RESULT(result);
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
std::move(*meta), result.take(),
nullptr, nullptr,
handle->device, handle->device,
handle->device_id); handle->device_id);
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu { namespace op::rms_norm::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create( ...@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info; auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c ...@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *y, const void *x, const void *w, void *workspace, size_t workspace_size,
void *stream) { void *y, const void *x, const void *w,
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) { if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w)); CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
......
#ifndef __RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::rms_norm {
/// Validated parameters for RMS normalization. Only creatable via `create`.
class RMSNormInfo {
    RMSNormInfo() = default;

public:
    infiniDtype_t wtype;
    infiniDtype_t atype;
    float epsilon;
    std::vector<size_t> shape;
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> x_strides;

    size_t ndim() const { return shape.size(); }
    size_t dim() const { return shape[ndim() - 1]; }

    /// Validate the (y, x, w) descriptor triple and capture what the
    /// kernels need. y and x must be 2-D [batch, dim] with a unit-stride
    /// last axis; w must be a contiguous 1-D [dim] vector. F16 activations
    /// accept F16 or F32 weights; F32/F64 activations require weights of
    /// the same type.
    static utils::Result<RMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t w_desc,
        float epsilon) {
        const auto atype = y_desc->dtype();
        const auto wtype = w_desc->dtype();

        // Activations (x and y) must share one dtype.
        if (x_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        // Weight dtype rules depend on the activation dtype.
        switch (atype) {
        case INFINI_DTYPE_F16:
            if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            break;
        case INFINI_DTYPE_F32:
        case INFINI_DTYPE_F64:
            if (wtype != atype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        const size_t rows = y_desc->shape()[0];
        const size_t cols = y_desc->shape()[1];
        const bool shapes_agree = x_desc->shape()[0] == rows
                               && x_desc->shape()[1] == cols
                               && w_desc->shape()[0] == cols;
        if (!shapes_agree) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Kernels index the innermost axis directly, so it must be dense.
        if (w_desc->stride(0) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        return utils::Result<RMSNormInfo>(RMSNormInfo{
            wtype,
            atype,
            epsilon,
            y_desc->shape(),
            y_desc->strides(),
            x_desc->strides(),
        });
    }
};
} // namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
#ifndef RMS_NORM_H #ifndef RMS_NORM_H
#define RMS_NORM_H #define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
/// Parameters for RMS normalization, filled in by createRMSNormInfo.
struct RMSNormInfo {
    infiniDtype_t wtype; // weight dtype
    infiniDtype_t atype; // activation (x/y) dtype
    float epsilon;
    std::vector<size_t> shape;        // y shape
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> x_strides;

    // const-qualified (fix): these accessors do not mutate the info and
    // must be callable through `const RMSNormInfo &` / `const RMSNormInfo *`.
    size_t ndim() const { return shape.size(); }
    size_t dim() const { return shape[ndim() - 1]; }
};
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->wtype = wtype;
info->atype = atype;
info->epsilon = epsilon;
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0]; #include "../../operator.h"
size_t dim = y_desc->shape()[1]; #include "info.h"
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; #define DESCRIPTOR(NAMESPACE) \
} \
namespace op::rms_norm::NAMESPACE { \
if (w_desc->stride(0) != 1) { class Descriptor final : public InfiniopDescriptor { \
return INFINI_STATUS_BAD_TENSOR_STRIDES; struct Opaque; \
} Opaque *_opaque; \
RMSNormInfo _info; \
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) { size_t _workspace_size; \
return INFINI_STATUS_BAD_TENSOR_STRIDES; \
} Descriptor( \
Opaque *opaque, \
info->shape = std::move(y_desc->shape()); RMSNormInfo info, \
info->y_strides = std::move(y_desc->strides()); size_t workspace_size, \
info->x_strides = std::move(x_desc->strides()); infiniDevice_t device_type, \
int device_id) \
return INFINI_STATUS_SUCCESS; : InfiniopDescriptor{device_type, device_id}, \
} _opaque(opaque), \
_info(info), \
#define DESCRIPTOR(NAMESPACE) \ _workspace_size(workspace_size) {} \
namespace op::rms_norm::NAMESPACE { \ \
class Descriptor final : public InfiniopDescriptor { \ public: \
struct Opaque; \ ~Descriptor(); \
Opaque *_opaque; \ \
RMSNormInfo _info; \ size_t workspaceSize() const { return _workspace_size; } \
size_t _workspace_size; \ \
\ static infiniStatus_t create( \
Descriptor( \ infiniopHandle_t handle, \
Opaque *opaque, \ Descriptor **desc_ptr, \
RMSNormInfo info, \ infiniopTensorDescriptor_t y_desc, \
size_t workspace_size, \ infiniopTensorDescriptor_t x_desc, \
infiniDevice_t device_type, \ infiniopTensorDescriptor_t w_desc, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \ float epsilon); \
_opaque(opaque), \ \
_info(info), \ infiniStatus_t calculate( \
_workspace_size(workspace_size) {} \ void *workspace, size_t workspace_size, \
\ void *y, \
public: \ const void *x, \
~Descriptor(); \ const void *w, \
size_t workspaceSize() const { return _workspace_size; } \ void *stream) const; \
static infiniStatus_t create( \ }; \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
}; \
} }
#endif // RMS_NORM_H #endif // RMS_NORM_H
#ifndef INFINIUTILS_H #ifndef INFINIUTILS_H
#define INFINIUTILS_H #define INFINIUTILS_H
#include "infinicore.h"
#include "utils/check.h"
#include "utils/custom_types.h" #include "utils/custom_types.h"
#include "utils/rearrange.h" #include "utils/rearrange.h"
......
...@@ -13,7 +13,7 @@ namespace utils { ...@@ -13,7 +13,7 @@ namespace utils {
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta) RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
: _meta(std::move(meta)) {} : _meta(std::move(meta)) {}
std::optional<RearrangeMeta> RearrangeMeta::create( Result<RearrangeMeta> RearrangeMeta::create(
const size_t *shape, const size_t *shape,
const ptrdiff_t *dst_strides_, const ptrdiff_t *dst_strides_,
const ptrdiff_t *src_strides_, const ptrdiff_t *src_strides_,
...@@ -32,7 +32,9 @@ std::optional<RearrangeMeta> RearrangeMeta::create( ...@@ -32,7 +32,9 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
// 剔除初始的 1 长维度 // 剔除初始的 1 长维度
if (shape[i] != 1) { if (shape[i] != 1) {
auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit; auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
// assert (sd != 0) if (sd == 0) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
dims.push_back(Dim{shape[i], sd, ss}); dims.push_back(Dim{shape[i], sd, ss});
} }
} }
...@@ -81,7 +83,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create( ...@@ -81,7 +83,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
for (ptrdiff_t i = ndim; i > 0; --i) { for (ptrdiff_t i = ndim; i > 0; --i) {
meta[1 + i - 1] *= meta[1 + i]; meta[1 + i - 1] *= meta[1 + i];
} }
return {RearrangeMeta(std::move(meta))}; return Result<RearrangeMeta>(meta);
} }
size_t RearrangeMeta::ndim() const { return (_meta.size() - 2) / 3; } size_t RearrangeMeta::ndim() const { return (_meta.size() - 2) / 3; }
......
#ifndef __INFINIUTILS_REARRANGE_H__ #ifndef __INFINIUTILS_REARRANGE_H__
#define __INFINIUTILS_REARRANGE_H__ #define __INFINIUTILS_REARRANGE_H__
#include <optional> #include "result.hpp"
#include <stddef.h> #include <cstddef>
#include <vector> #include <vector>
namespace utils { namespace utils {
...@@ -12,7 +12,7 @@ class RearrangeMeta { ...@@ -12,7 +12,7 @@ class RearrangeMeta {
RearrangeMeta(std::vector<ptrdiff_t>); RearrangeMeta(std::vector<ptrdiff_t>);
public: public:
static std::optional<RearrangeMeta> create( static Result<RearrangeMeta> create(
const size_t *shape, const size_t *shape,
const ptrdiff_t *dst_strides, const ptrdiff_t *dst_strides,
const ptrdiff_t *src_strides, const ptrdiff_t *src_strides,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment