Commit fab5ed70 authored by YdrMaster's avatar YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 MatMulInfo 构造流程


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent a0eab6bf
...@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create( ...@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) { auto info = result.take();
return status;
}
auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()), auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()),
{static_cast<int64_t>(info.c_matrix.rows), static_cast<int64_t>(info.c_matrix.cols)}, {static_cast<int64_t>(info.m), static_cast<int64_t>(info.n)},
{info.c_matrix.row_stride, info.c_matrix.col_stride}); {info.c_matrix.row_stride, info.c_matrix.col_stride});
auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()), auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()),
{static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)}, {static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)},
......
...@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create( ...@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) { auto info = result.take();
return status;
}
cnnlTensorDescriptor_t a, b, c; cnnlTensorDescriptor_t a, b, c;
CHECK_BANG(cnnlCreateTensorDescriptor(&a)); CHECK_BANG(cnnlCreateTensorDescriptor(&a));
......
...@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create( ...@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
nullptr, nullptr,
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create( ...@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define __GEMM_H__ #define __GEMM_H__
#include "../../operator.h" #include "../../operator.h"
#include "blas.h" #include "matmul_info.h"
/** /**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明 * # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
......
...@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create( ...@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create( ...@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
infiniStatus_t status; auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, 0, dtype, result.take(), 0,
new Opaque{handle->internal()}, new Opaque{handle->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
#ifndef __BLAS_H__ #ifndef __BLAS_H__
#define __BLAS_H__ #define __BLAS_H__
#include "../../../utils.h"
#include "../../operator.h" #include "../../operator.h"
#include "../../tensor.h" #include "../../tensor.h"
#include <algorithm> #include <algorithm>
namespace op::gemm { namespace op::gemm {
struct BlasMatrix { class BlasMatrix {
BlasMatrix() = default;
public:
size_t ndim; size_t ndim;
size_t batch; size_t batch;
ptrdiff_t stride; ptrdiff_t stride;
...@@ -16,36 +20,34 @@ struct BlasMatrix { ...@@ -16,36 +20,34 @@ struct BlasMatrix {
ptrdiff_t row_stride; ptrdiff_t row_stride;
ptrdiff_t col_stride; ptrdiff_t col_stride;
BlasMatrix() = default; static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
BlasMatrix ans;
BlasMatrix(infiniopTensorDescriptor_t layout, infiniStatus_t *status) {
if (layout->ndim() == 2) { if (layout->ndim() == 2) {
ndim = 2; ans.ndim = 2;
batch = 1; ans.batch = 1;
stride = 0; ans.stride = 0;
rows = layout->dim(0); ans.rows = layout->dim(0);
cols = layout->dim(1); ans.cols = layout->dim(1);
row_stride = layout->stride(0); ans.row_stride = layout->stride(0);
col_stride = layout->stride(1); ans.col_stride = layout->stride(1);
} else if (layout->ndim() == 3) { } else if (layout->ndim() == 3) {
ndim = 3; ans.ndim = 3;
batch = layout->dim(0); ans.batch = layout->dim(0);
stride = batch == 1 ? 0 : layout->stride(0); ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
rows = layout->dim(1); ans.rows = layout->dim(1);
cols = layout->dim(2); ans.cols = layout->dim(2);
row_stride = layout->stride(1); ans.row_stride = layout->stride(1);
col_stride = layout->stride(2); ans.col_stride = layout->stride(2);
} else { } else {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
} }
if (row_stride != 1 && col_stride != 1) { if (ans.row_stride != 1 && ans.col_stride != 1) {
*status = INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
return;
} }
*status = INFINI_STATUS_SUCCESS; return utils::Result<BlasMatrix>(ans);
} }
bool match_batch(size_t _batch) const { bool match_batch(size_t _batch) const {
...@@ -67,56 +69,64 @@ enum class MatrixLayout : char { ...@@ -67,56 +69,64 @@ enum class MatrixLayout : char {
ROW_MAJOR, ROW_MAJOR,
}; };
struct MatmulInfo { class MatmulInfo {
MatmulInfo() = default;
public:
BlasMatrix a_matrix; BlasMatrix a_matrix;
BlasMatrix b_matrix; BlasMatrix b_matrix;
BlasMatrix c_matrix; BlasMatrix c_matrix;
size_t m, n, k, batch; size_t m, n, k, batch;
bool is_transed;
bool is_transed = false; static utils::Result<MatmulInfo> create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
MatrixLayout layout) {
MatmulInfo(infiniopTensorDescriptor_t c_desc, auto a_matrix = BlasMatrix::create(a_desc);
infiniopTensorDescriptor_t a_desc, CHECK_RESULT(a_matrix);
infiniopTensorDescriptor_t b_desc,
infiniStatus_t *status, auto b_matrix = BlasMatrix::create(b_desc);
MatrixLayout layout) { CHECK_RESULT(b_matrix);
a_matrix = BlasMatrix(a_desc, status);
if (*status != INFINI_STATUS_SUCCESS) { auto c_matrix = BlasMatrix::create(c_desc);
return; CHECK_RESULT(c_matrix);
}
b_matrix = BlasMatrix(b_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
c_matrix = BlasMatrix(c_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
if (c_matrix.rows != a_matrix.rows || c_matrix.cols != b_matrix.cols || a_matrix.cols != b_matrix.rows) { if (c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
} }
batch = c_matrix.batch; auto batch = c_matrix->batch;
if (!a_matrix.match_batch(batch) || !b_matrix.match_batch(batch)) { if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
} }
if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1) auto is_transed = false;
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) { if ((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1)
c_matrix.transpose(); || (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) {
b_matrix.transpose(); c_matrix->transpose();
a_matrix.transpose(); b_matrix->transpose();
a_matrix->transpose();
std::swap(a_matrix, b_matrix); std::swap(a_matrix, b_matrix);
is_transed = true; is_transed = true;
} }
m = c_matrix.rows; auto m = c_matrix->rows;
n = c_matrix.cols; auto n = c_matrix->cols;
k = a_matrix.cols; auto k = a_matrix->cols;
return utils::Result<MatmulInfo>(MatmulInfo{
a_matrix.take(),
b_matrix.take(),
c_matrix.take(),
m,
n,
k,
batch,
is_transed});
} }
}; };
......
...@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create( ...@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create(
auto dst_strides = y_desc->strides().data(); auto dst_strides = y_desc->strides().data();
auto src_strides = x_desc->strides().data(); auto src_strides = x_desc->strides().data();
auto element_size = infiniSizeOf(dtype); auto element_size = infiniSizeOf(dtype);
auto meta = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
if (!meta) { auto result = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
return INFINI_STATUS_BAD_TENSOR_STRIDES; CHECK_RESULT(result);
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
std::move(*meta), result.take(),
nullptr, nullptr,
handle->device, handle->device,
handle->device_id); handle->device_id);
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define INFINIUTILS_H #define INFINIUTILS_H
#include "infinicore.h" #include "infinicore.h"
#include "utils/check.h"
#include "utils/custom_types.h" #include "utils/custom_types.h"
#include "utils/rearrange.h" #include "utils/rearrange.h"
......
...@@ -33,7 +33,7 @@ Result<RearrangeMeta> RearrangeMeta::create( ...@@ -33,7 +33,7 @@ Result<RearrangeMeta> RearrangeMeta::create(
if (shape[i] != 1) { if (shape[i] != 1) {
auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit; auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
if (sd == 0) { if (sd == 0) {
return Result<RearrangeMeta>(INFINI_STATUS_BAD_TENSOR_STRIDES); return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
dims.push_back(Dim{shape[i], sd, ss}); dims.push_back(Dim{shape[i], sd, ss});
} }
......
...@@ -5,6 +5,11 @@ ...@@ -5,6 +5,11 @@
#include <infinicore.h> #include <infinicore.h>
#include <variant> #include <variant>
#define CHECK_RESULT(RESULT) \
if (!RESULT) { \
return RESULT.status(); \
}
namespace utils { namespace utils {
template <typename T, typename = std::enable_if_t<!std::is_same_v<T, infiniStatus_t>>> template <typename T, typename = std::enable_if_t<!std::is_same_v<T, infiniStatus_t>>>
...@@ -13,7 +18,7 @@ class Result { ...@@ -13,7 +18,7 @@ class Result {
public: public:
explicit Result(T value) : _result(std::move(value)) {} explicit Result(T value) : _result(std::move(value)) {}
explicit Result(infiniStatus_t status) : _result(status) { Result(infiniStatus_t status) : _result(status) {
if (status == INFINI_STATUS_SUCCESS) { if (status == INFINI_STATUS_SUCCESS) {
std::cerr << "Warning: Result created with success status but value is not set." << std::endl; std::cerr << "Warning: Result created with success status but value is not set." << std::endl;
std::abort(); std::abort();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment