Commit fab5ed70 authored by YdrMaster's avatar YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 MatMulInfo 构造流程


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent a0eab6bf
......@@ -38,14 +38,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
auto info = result.take();
auto c = new aclnnTensorDescriptor(toAclDataType(c_desc->dtype()),
{static_cast<int64_t>(info.c_matrix.rows), static_cast<int64_t>(info.c_matrix.cols)},
{static_cast<int64_t>(info.m), static_cast<int64_t>(info.n)},
{info.c_matrix.row_stride, info.c_matrix.col_stride});
auto a = new aclnnTensorDescriptor(toAclDataType(a_desc->dtype()),
{static_cast<int64_t>(info.a_matrix.rows), static_cast<int64_t>(info.a_matrix.cols)},
......
......@@ -71,11 +71,9 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
auto info = result.take();
cnnlTensorDescriptor_t a, b, c;
CHECK_BANG(cnnlCreateTensorDescriptor(&a));
......
......@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
nullptr,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -2,7 +2,7 @@
#define __GEMM_H__
#include "../../operator.h"
#include "blas.h"
#include "matmul_info.h"
/**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
......
......@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
#ifndef __BLAS_H__
#define __BLAS_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
struct BlasMatrix {
class BlasMatrix {
BlasMatrix() = default;
public:
size_t ndim;
size_t batch;
ptrdiff_t stride;
......@@ -16,36 +20,34 @@ struct BlasMatrix {
ptrdiff_t row_stride;
ptrdiff_t col_stride;
BlasMatrix() = default;
static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
BlasMatrix ans;
BlasMatrix(infiniopTensorDescriptor_t layout, infiniStatus_t *status) {
if (layout->ndim() == 2) {
ndim = 2;
batch = 1;
stride = 0;
rows = layout->dim(0);
cols = layout->dim(1);
row_stride = layout->stride(0);
col_stride = layout->stride(1);
ans.ndim = 2;
ans.batch = 1;
ans.stride = 0;
ans.rows = layout->dim(0);
ans.cols = layout->dim(1);
ans.row_stride = layout->stride(0);
ans.col_stride = layout->stride(1);
} else if (layout->ndim() == 3) {
ndim = 3;
batch = layout->dim(0);
stride = batch == 1 ? 0 : layout->stride(0);
rows = layout->dim(1);
cols = layout->dim(2);
row_stride = layout->stride(1);
col_stride = layout->stride(2);
ans.ndim = 3;
ans.batch = layout->dim(0);
ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
ans.rows = layout->dim(1);
ans.cols = layout->dim(2);
ans.row_stride = layout->stride(1);
ans.col_stride = layout->stride(2);
} else {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (row_stride != 1 && col_stride != 1) {
*status = INFINI_STATUS_BAD_TENSOR_STRIDES;
return;
if (ans.row_stride != 1 && ans.col_stride != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*status = INFINI_STATUS_SUCCESS;
return utils::Result<BlasMatrix>(ans);
}
bool match_batch(size_t _batch) const {
......@@ -67,56 +69,64 @@ enum class MatrixLayout : char {
ROW_MAJOR,
};
struct MatmulInfo {
class MatmulInfo {
MatmulInfo() = default;
public:
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix c_matrix;
size_t m, n, k, batch;
bool is_transed;
bool is_transed = false;
MatmulInfo(infiniopTensorDescriptor_t c_desc,
static utils::Result<MatmulInfo> create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniStatus_t *status,
MatrixLayout layout) {
a_matrix = BlasMatrix(a_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
b_matrix = BlasMatrix(b_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
c_matrix = BlasMatrix(c_desc, status);
if (*status != INFINI_STATUS_SUCCESS) {
return;
}
if (c_matrix.rows != a_matrix.rows || c_matrix.cols != b_matrix.cols || a_matrix.cols != b_matrix.rows) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
auto a_matrix = BlasMatrix::create(a_desc);
CHECK_RESULT(a_matrix);
auto b_matrix = BlasMatrix::create(b_desc);
CHECK_RESULT(b_matrix);
auto c_matrix = BlasMatrix::create(c_desc);
CHECK_RESULT(c_matrix);
if (c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
batch = c_matrix.batch;
if (!a_matrix.match_batch(batch) || !b_matrix.match_batch(batch)) {
*status = INFINI_STATUS_BAD_TENSOR_SHAPE;
return;
auto batch = c_matrix->batch;
if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) {
c_matrix.transpose();
b_matrix.transpose();
a_matrix.transpose();
auto is_transed = false;
if ((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) {
c_matrix->transpose();
b_matrix->transpose();
a_matrix->transpose();
std::swap(a_matrix, b_matrix);
is_transed = true;
}
m = c_matrix.rows;
n = c_matrix.cols;
k = a_matrix.cols;
auto m = c_matrix->rows;
auto n = c_matrix->cols;
auto k = a_matrix->cols;
return utils::Result<MatmulInfo>(MatmulInfo{
a_matrix.take(),
b_matrix.take(),
c_matrix.take(),
m,
n,
k,
batch,
is_transed});
}
};
......
......@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create(
auto dst_strides = y_desc->strides().data();
auto src_strides = x_desc->strides().data();
auto element_size = infiniSizeOf(dtype);
auto meta = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
if (!meta) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
auto result = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
std::move(*meta),
result.take(),
nullptr,
handle->device,
handle->device_id);
......
......@@ -2,7 +2,6 @@
#define INFINIUTILS_H
#include "infinicore.h"
#include "utils/check.h"
#include "utils/custom_types.h"
#include "utils/rearrange.h"
......
......@@ -33,7 +33,7 @@ Result<RearrangeMeta> RearrangeMeta::create(
if (shape[i] != 1) {
auto sd = dst_strides_[i] * unit, ss = src_strides_[i] * unit;
if (sd == 0) {
return Result<RearrangeMeta>(INFINI_STATUS_BAD_TENSOR_STRIDES);
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
dims.push_back(Dim{shape[i], sd, ss});
}
......
......@@ -5,6 +5,11 @@
#include <infinicore.h>
#include <variant>
#define CHECK_RESULT(RESULT) \
if (!RESULT) { \
return RESULT.status(); \
}
namespace utils {
template <typename T, typename = std::enable_if_t<!std::is_same_v<T, infiniStatus_t>>>
......@@ -13,7 +18,7 @@ class Result {
public:
explicit Result(T value) : _result(std::move(value)) {}
explicit Result(infiniStatus_t status) : _result(status) {
Result(infiniStatus_t status) : _result(status) {
if (status == INFINI_STATUS_SUCCESS) {
std::cerr << "Warning: Result created with success status but value is not set." << std::endl;
std::abort();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment