Commit f7137096 authored by YdrMaster
Browse files

issue/63/refactor: 重构 Matmul 所有实现,添加命名空间


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent 8e34901e
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define __INFINIOP__COMMON_CPU_H__ #define __INFINIOP__COMMON_CPU_H__
#include <cmath> #include <cmath>
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
......
...@@ -48,7 +48,7 @@ struct InfiniopCudaHandle { ...@@ -48,7 +48,7 @@ struct InfiniopCudaHandle {
}; };
template <typename T> template <typename T>
void use_cublas(std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool, int device_id, cudaStream_t stream, T const &f) { void use_cublas(std::shared_ptr<Pool<cublasHandle_t>> &cublas_handle_pool, cudaStream_t stream, T const &f) {
auto handle = cublas_handle_pool->pop(); auto handle = cublas_handle_pool->pop();
if (!handle) { if (!handle) {
cublasCreate(&(*handle)); cublasCreate(&(*handle));
......
#ifndef __ACLNN_MATMUL_H__
#define __ACLNN_MATMUL_H__
#include "../../../devices/ascend/tensor_aclnn.h"
#include "../../utils.h"
#include "../blas.h"
#include "matmul_aclnn_api.h"
#include <acl/acl_base.h>
#include <aclnn/acl_meta.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/level2/aclnn_gemm.h>
// State for one Ascend (aclnn) matmul operator instance.
// Created via aclnnCreateMatmulDescriptor, destroyed via
// aclnnDestroyMatmulDescriptor (declared in matmul_aclnn_api.h).
struct InfiniopMatmulAclnnDescriptor {
// Device kind this descriptor was built for.
infiniDevice_t device;
int device_id;
// aclnn executor prepared at create time and reused per launch.
aclOpExecutor *executor;
// Heap-allocated matmul shape/stride info; owned by this descriptor.
MatmulInfo *info;
// Element dtype of C (F16 or F32 only — validated at create time).
infiniDtype_t dtype;
// aclnn tensor views for C, A, B; owned by this descriptor.
aclnnTensorDescriptor_t cDesc, aDesc, bDesc;
// cubeMathType
// see doc:
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnBatchMatMul.md
int8_t mt;
// Workspace bytes reported by aclnnGemmGetWorkspaceSize at create time.
size_t workspaceSize;
InfiniopMatmulAclnnDescriptor(infiniDevice_t _device);
};
#endif
#ifndef __INFINIOP_MATMUL_ACLNN_API_H__
#define __INFINIOP_MATMUL_ACLNN_API_H__
#include "../../../devices/ascend/ascend_handle.h"
#include "infiniop/operator.h"
// Opaque handle; definition lives in matmul_aclnn.h.
struct InfiniopMatmulAclnnDescriptor;
typedef struct InfiniopMatmulAclnnDescriptor *MatmulAclnnDescriptor_t;
// Build a descriptor for C = alpha * A @ B + beta * C on Ascend.
// cubeMathType selects the cube-unit math mode (see aclnn docs).
infiniopStatus_t aclnnCreateMatmulDescriptor(infiniopAscendHandle_t handle,
MatmulAclnnDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
int8_t cubeMathType);
// Report the workspace size computed at descriptor-creation time.
infiniopStatus_t aclnnGetMatmulWorkspaceSize(MatmulAclnnDescriptor_t desc,
size_t *size);
// Run the matmul; `workspace` must be at least the size reported above.
infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace,
size_t workspace_size, void *c, const void *a,
const void *b, float alpha, float beta,
void *stream);
// Free everything owned by the descriptor (tensor views, info, executor).
infiniopStatus_t aclnnDestroyMatmulDescriptor(MatmulAclnnDescriptor_t desc);
#endif // __INFINIOP_MATMUL_ACLNN_API_H__
#include "matmul_aclnn.h" #include "matmul_ascend.h"
#include "../../../devices/ascend/tensor_aclnn.h"
InfiniopMatmulAclnnDescriptor::InfiniopMatmulAclnnDescriptor( #include "../../utils.h"
infiniDevice_t _device) { #include <acl/acl_base.h>
device = _device; #include <aclnn/acl_meta.h>
device_id = 0; #include <aclnnop/aclnn_matmul.h>
executor = nullptr; #include <aclnnop/level2/aclnn_gemm.h>
info = nullptr;
cDesc = new aclnnTensorDescriptor(); namespace matmul::ascend {
aDesc = new aclnnTensorDescriptor();
bDesc = new aclnnTensorDescriptor(); struct Descriptor::Opaque {
mt = 1; mutable aclOpExecutor *executor;
workspaceSize = 0; aclnnTensorDescriptor_t cDesc, aDesc, bDesc;
// cubeMathType
// see doc:
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnBatchMatMul.md
int8_t mt;
~Opaque() {
delete cDesc;
delete aDesc;
delete bDesc;
aclDestroyAclOpExecutor(executor);
}
};
Descriptor::~Descriptor() {
delete _opaque;
} }
infiniopStatus_t aclnnCreateMatmulDescriptor(infiniopAscendHandle_t handle, infiniopStatus_t Descriptor::create(
MatmulAclnnDescriptor_t *desc_ptr, infiniopAscendHandle_t handle,
infiniopTensorDescriptor_t c_desc, Descriptor **desc_ptr,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t a_desc,
int8_t mt) { infiniopTensorDescriptor_t b_desc) {
infiniDtype_t dtype = c_desc->dtype; infiniDtype_t dtype = c_desc->dtype;
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINIOP_STATUS_BAD_TENSOR_DTYPE; return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
} }
*desc_ptr = new InfiniopMatmulAclnnDescriptor(handle->device);
(*desc_ptr)->device_id = handle->device_id;
(*desc_ptr)->dtype = dtype;
(*desc_ptr)->mt = mt;
infiniopStatus_t status; infiniopStatus_t status;
auto info = new MatmulInfo(c_desc, a_desc, b_desc, &status, false); auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) { if (status != INFINIOP_STATUS_SUCCESS) {
return status; return status;
} }
(*desc_ptr)->info = info;
auto &cDesc = (*desc_ptr)->cDesc; auto cDesc = new aclnnTensorDescriptor(),
auto &aDesc = (*desc_ptr)->aDesc; aDesc = new aclnnTensorDescriptor(),
auto &bDesc = (*desc_ptr)->bDesc; bDesc = new aclnnTensorDescriptor();
// Treat A, B, C as 2D matrix, reuse aclnnTensorDescriptor for batched // Treat A, B, C as 2D matrix, reuse aclnnTensorDescriptor for batched
// operation // operation
CHECK_STATUS(cDesc->setDescriptor( CHECK_STATUS(cDesc->setDescriptor(
toAclDataType(c_desc->dtype), toAclDataType(c_desc->dtype),
{static_cast<int64_t>(info->c_matrix.rows), {static_cast<int64_t>(info.c_matrix.rows),
static_cast<int64_t>(info->c_matrix.cols)}, static_cast<int64_t>(info.c_matrix.cols)},
{info->c_matrix.row_stride, info->c_matrix.col_stride}), {info.c_matrix.row_stride, info.c_matrix.col_stride}),
INFINIOP_STATUS_SUCCESS); INFINIOP_STATUS_SUCCESS);
CHECK_STATUS(aDesc->setDescriptor( CHECK_STATUS(aDesc->setDescriptor(
toAclDataType(a_desc->dtype), toAclDataType(a_desc->dtype),
{static_cast<int64_t>(info->a_matrix.rows), {static_cast<int64_t>(info.a_matrix.rows),
static_cast<int64_t>(info->a_matrix.cols)}, static_cast<int64_t>(info.a_matrix.cols)},
{info->a_matrix.row_stride, info->a_matrix.col_stride}), {info.a_matrix.row_stride, info.a_matrix.col_stride}),
INFINIOP_STATUS_SUCCESS); INFINIOP_STATUS_SUCCESS);
CHECK_STATUS(bDesc->setDescriptor( CHECK_STATUS(bDesc->setDescriptor(
toAclDataType(b_desc->dtype), toAclDataType(b_desc->dtype),
{static_cast<int64_t>(info->b_matrix.rows), {static_cast<int64_t>(info.b_matrix.rows),
static_cast<int64_t>(info->b_matrix.cols)}, static_cast<int64_t>(info.b_matrix.cols)},
{info->b_matrix.row_stride, info->b_matrix.col_stride}), {info.b_matrix.row_stride, info.b_matrix.col_stride}),
INFINIOP_STATUS_SUCCESS); INFINIOP_STATUS_SUCCESS);
CHECK_STATUS(cDesc->createTensor(), INFINIOP_STATUS_SUCCESS); CHECK_STATUS(cDesc->createTensor(), INFINIOP_STATUS_SUCCESS);
CHECK_STATUS(aDesc->createTensor(), INFINIOP_STATUS_SUCCESS); CHECK_STATUS(aDesc->createTensor(), INFINIOP_STATUS_SUCCESS);
CHECK_STATUS(bDesc->createTensor(), INFINIOP_STATUS_SUCCESS); CHECK_STATUS(bDesc->createTensor(), INFINIOP_STATUS_SUCCESS);
auto &workspaceSize = (*desc_ptr)->workspaceSize; auto tc = cDesc->t,
auto &executor = (*desc_ptr)->executor; ta = aDesc->t,
tb = bDesc->t;
aclTensor *tc = cDesc->t; aclOpExecutor *executor;
aclTensor *ta = aDesc->t; size_t workspaceSize;
aclTensor *tb = bDesc->t;
aclnnStatus ret;
int64_t transA = 0;
int64_t transB = 0;
// aclnnGemm support C = alpha * A @ B + beta * C // aclnnGemm support C = alpha * A @ B + beta * C
// see // see
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md // https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md
// use alpha = 0.5, beta = 0.5 temporarily // use alpha = 0.5, beta = 0.5 temporarily
ret = aclnnGemmGetWorkspaceSize(ta, tb, tc, 0.5f, 0.5f, transA, transB, tc,
(*desc_ptr)->mt, &workspaceSize, &executor); int8_t mt = 1;
auto ret = aclnnGemmGetWorkspaceSize(ta, tb, tc, .5, .5, 0, 0, tc, mt, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n", ret); LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n", ret);
return INFINIOP_STATUS_INTERNAL_ERROR); return INFINIOP_STATUS_INTERNAL_ERROR);
aclSetAclOpExecutorRepeatable(executor); aclSetAclOpExecutorRepeatable(executor);
*desc_ptr = new Descriptor(
dtype, info, workspaceSize,
new Opaque{
executor,
cDesc,
aDesc,
bDesc,
mt,
},
handle->device, handle->device_id);
return INFINIOP_STATUS_SUCCESS; return INFINIOP_STATUS_SUCCESS;
} }
infiniopStatus_t aclnnGetMatmulWorkspaceSize(MatmulAclnnDescriptor_t desc, infiniopStatus_t Descriptor::calculate(
size_t *size) { void *workspace,
*size = desc->workspaceSize; size_t workspaceSize_,
return INFINIOP_STATUS_SUCCESS; void *c,
} float beta,
void const *a,
infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace, void const *b,
size_t workspace_size, void *c, void const *a, float alpha,
void const *b, float alpha, float beta, void *stream) const {
void *stream) {
auto &cDesc = desc->cDesc;
auto &aDesc = desc->aDesc;
auto &bDesc = desc->bDesc;
aclTensor *tc = cDesc->t;
aclTensor *ta = aDesc->t;
aclTensor *tb = bDesc->t;
auto batch = desc->info->batch; auto tc = _opaque->cDesc->t,
ta = _opaque->aDesc->t,
tb = _opaque->bDesc->t;
size_t workspaceSize; size_t workspaceSize;
aclnnStatus ret; auto ret = aclnnGemmGetWorkspaceSize(
ret = aclnnGemmGetWorkspaceSize(ta, tb, tc, alpha, beta, 0, 0, tc, desc->mt, ta, tb, tc, alpha, beta, 0, 0, tc, _opaque->mt,
&workspaceSize, &(desc->executor)); &workspaceSize, &(_opaque->executor));
CHECK_RET(ret == ACL_SUCCESS, CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n", ret); LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n", ret);
return INFINIOP_STATUS_INTERNAL_ERROR); return INFINIOP_STATUS_INTERNAL_ERROR);
if (workspace_size < workspaceSize) { if (workspaceSize_ < workspaceSize) {
return INFINIOP_STATUS_INSUFFICIENT_WORKSPACE; return INFINIOP_STATUS_INSUFFICIENT_WORKSPACE;
} }
aclSetAclOpExecutorRepeatable(desc->executor); aclSetAclOpExecutorRepeatable(_opaque->executor);
for (size_t i = 0; i < batch; i++) { for (size_t i = 0; i < info.batch; ++i) {
AclSetTensorAddr(desc->executor, 0, ta, AclSetTensorAddr(_opaque->executor, 0, ta, ((char *)a) + i * info.a_matrix.stride * infiniSizeof(dtype));
(char *)(a) + i * desc->info->a_matrix.stride * infiniSizeof(desc->dtype)); AclSetTensorAddr(_opaque->executor, 1, tb, ((char *)b) + i * info.b_matrix.stride * infiniSizeof(dtype));
AclSetTensorAddr(desc->executor, 1, tb, AclSetTensorAddr(_opaque->executor, 2, tc, ((char *)c) + i * info.c_matrix.stride * infiniSizeof(dtype));
(char *)(b) + i * desc->info->b_matrix.stride * infiniSizeof(desc->dtype)); AclSetTensorAddr(_opaque->executor, 3, tc, ((char *)c) + i * info.c_matrix.stride * infiniSizeof(dtype));
AclSetTensorAddr(desc->executor, 2, tc, ret = aclnnGemm(workspace, workspaceSize, _opaque->executor, stream);
(char *)(c) + i * desc->info->c_matrix.stride * infiniSizeof(desc->dtype));
AclSetTensorAddr(desc->executor, 3, tc,
(char *)(c) + i * desc->info->c_matrix.stride * infiniSizeof(desc->dtype));
ret = aclnnGemm(workspace, workspaceSize, desc->executor, stream);
CHECK_RET(ret == ACL_SUCCESS, CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclnnGemm failed. ERROR: %d\n", ret); LOG_PRINT("aclnnGemm failed. ERROR: %d\n", ret);
return INFINIOP_STATUS_INTERNAL_ERROR); return INFINIOP_STATUS_INTERNAL_ERROR);
...@@ -139,13 +145,4 @@ infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace, ...@@ -139,13 +145,4 @@ infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace,
return INFINIOP_STATUS_SUCCESS; return INFINIOP_STATUS_SUCCESS;
} }
infiniopStatus_t aclnnDestroyMatmulDescriptor(MatmulAclnnDescriptor_t desc) { } // namespace matmul::ascend
delete desc->cDesc;
delete desc->bDesc;
delete desc->aDesc;
delete desc->info;
aclDestroyAclOpExecutor(desc->executor);
delete desc;
return INFINIOP_STATUS_SUCCESS;
}
#ifndef __MATMUL_ASCEND_H__
#define __MATMUL_ASCEND_H__
#include "../../../devices/ascend/ascend_handle.h"
#include "../matmul.h"
// Instantiates the shared matmul Descriptor interface for the Ascend
// backend — presumably declares matmul::ascend::Descriptor with
// create()/calculate() specialized on infiniopAscendHandle_t; the macro
// itself is defined in ../matmul.h (not visible here — confirm there).
DESCRIPTOR(ascend, infiniopAscendHandle_t)
#endif // __MATMUL_ASCEND_H__
#include "matmul_bang.h"
#include "../../../devices/bang/common_bang.h"
#include "../../utils.h"
#include <cnnl_extra.h>
namespace matmul::bang {
// Backend-private state for the cambricon (bang/cnnl) matmul descriptor.
// All cnnl objects are created in Descriptor::create and released here.
struct Descriptor::Opaque {
// cnnl matmul op config; CNNL_MATMUL_USE_STRIDE is enabled at create time.
cnnlMatMulDescriptor_t opDesc;
// Algorithm selected by cnnlGetBatchMatMulAlgoHeuristic at create time.
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
// Tensor layouts for A, B, C built from the MatmulInfo matrices.
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
// Shared handle pool so calculate() can borrow a cnnlHandle_t per call.
std::shared_ptr<Pool<cnnlHandle_t>> cnnl_handle_pool;
// Release all cnnl objects; the pool reference drops via shared_ptr.
~Opaque() {
cnnlDestroyTensorDescriptor(aDesc);
cnnlDestroyTensorDescriptor(bDesc);
cnnlDestroyTensorDescriptor(cDesc);
cnnlMatMulDescDestroy(opDesc);
cnnlMatMulAlgoDestroy(algo);
cnnlDestroyMatMulHeuristicResult(algoResult);
}
};
// Fill `desc` with the array layout described by `matrix`.
// Rank 3 becomes {batch, rows, cols} with strides {stride, row, col};
// rank 2 becomes {rows, cols}. Any other rank leaves `desc` unmodified
// (callers only ever pass 2-D or 3-D BlasMatrix views).
// `trans` is currently unused; kept for interface compatibility.
static void setMatrixTensorEx(
    cnnlTensorDescriptor_t desc,
    const BlasMatrix &matrix, infiniDtype_t dtype,
    bool trans = false) {
    std::vector<int> sizes;
    std::vector<int> strides;
    if (matrix.ndim == 3) {
        sizes = {static_cast<int>(matrix.batch),
                 static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.stride),
                   static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
    } else if (matrix.ndim == 2) {
        sizes = {static_cast<int>(matrix.rows),
                 static_cast<int>(matrix.cols)};
        strides = {static_cast<int>(matrix.row_stride),
                   static_cast<int>(matrix.col_stride)};
    } else {
        return;
    }
    cnnlSetTensorDescriptorEx(
        desc, CNNL_LAYOUT_ARRAY,
        cnnlDataTypeConvert(dtype), sizes.size(),
        sizes.data(), strides.data());
}
// Tear down the bang backend state; ~Opaque releases all cnnl objects.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build a bang (cnnl) matmul descriptor for C = alpha * A @ B + beta * C.
// Validates dtype, derives matrix shapes/strides via MatmulInfo, creates
// the cnnl tensor/op/algo objects, runs the heuristic once to pick an
// algorithm and workspace size, then packs everything into *desc_ptr.
infiniopStatus_t Descriptor::create(
infiniopBangHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
// Only half and single precision are supported by this backend.
infiniDtype_t dtype = c_desc->dtype;
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) {
return status;
}
// NOTE(review): the cnnlCreate*/cnnlSet* calls below return statuses that
// are not checked; a failure here surfaces only later at calculate time.
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
cnnlCreateTensorDescriptor(&aDesc);
cnnlCreateTensorDescriptor(&bDesc);
cnnlCreateTensorDescriptor(&cDesc);
setMatrixTensorEx(aDesc, info.a_matrix, a_desc->dtype);
setMatrixTensorEx(bDesc, info.b_matrix, b_desc->dtype);
setMatrixTensorEx(cDesc, info.c_matrix, c_desc->dtype);
cnnlMatMulDescriptor_t opDesc;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
cnnlMatMulDescCreate(&opDesc);
cnnlMatMulAlgoCreate(&algo);
cnnlCreateMatMulHeuristicResult(&algoResult);
// Tell cnnl to honor the explicit strides set on the tensor descriptors.
int32_t use_stride = true;
cnnlSetMatMulDescAttr(
opDesc,
CNNL_MATMUL_USE_STRIDE,
&use_stride,
sizeof(int32_t));
// Borrow a handle from the pool just to run the algorithm heuristic;
// request a single result (the 1 below), count receives how many came back.
int count = 0;
use_cnnl(handle->cnnl_handle_pool,
[&](cnnlHandle_t _handle) {
cnnlGetBatchMatMulAlgoHeuristic(
_handle,
opDesc, aDesc, bDesc, cDesc,
NULL, 1, &algoResult, &count);
});
// Extract the chosen algo and its workspace requirement from the result.
size_t workspace_size;
cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size);
// Opaque takes ownership of every cnnl object created above (see ~Opaque).
*desc_ptr = new Descriptor(
dtype, info, workspace_size,
new Opaque{
opDesc,
algo,
algoResult,
aDesc,
bDesc,
cDesc,
handle->cnnl_handle_pool},
handle->device, handle->device_id);
return INFINIOP_STATUS_SUCCESS;
}
// Run C = alpha * A @ B + beta * C on the MLU queue `stream`, then block
// until the queue drains so results are visible to the caller on return.
// `workspace` must hold at least the size reported at create time.
infiniopStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
void const *a,
void const *b,
float alpha,
void *stream) const {
// MatmulInfo may have transposed the whole problem (is_transed); in that
// case A and B swap roles so they match the descriptors built at create.
if (info.is_transed) {
std::swap(a, b);
}
// Borrow a cnnl handle bound to this queue and enqueue the batched GEMM.
// NOTE(review): the cnnlBatchMatMulBCast_v2 status is not checked, and
// workspace_size is forwarded without comparing against the required size.
use_cnnl(_opaque->cnnl_handle_pool,
(cnrtQueue_t)stream,
[&](cnnlHandle_t handle) {
cnnlBatchMatMulBCast_v2(
handle,
_opaque->opDesc,
_opaque->algo,
&alpha,
_opaque->aDesc, a,
_opaque->bDesc, b,
&beta,
_opaque->cDesc, c,
workspace,
workspace_size);
});
// Synchronous semantics: wait for the enqueued work to finish.
cnrtQueueSync((cnrtQueue_t)stream);
return INFINIOP_STATUS_SUCCESS;
}
} // namespace matmul::bang
#ifndef __MATMUL_BANG_H__
#define __MATMUL_BANG_H__
#include "../../../devices/bang/bang_handle.h"
#include "../matmul.h"
// Instantiates the shared matmul Descriptor interface for the cambricon
// (bang) backend — presumably declares matmul::bang::Descriptor with
// create()/calculate() specialized on infiniopBangHandle_t; the macro is
// defined in ../matmul.h (not visible here — confirm there).
DESCRIPTOR(bang, infiniopBangHandle_t)
#endif // __MATMUL_BANG_H__
#include "matmul_cnnl.h"
#include "../../../devices/bang/common_bang.h"
#include "../../utils.h"
#include "matmul_cnnl_api.h"
// Build a cnnl matmul descriptor: derive matrix views, create the cnnl
// tensor/op/algo objects, run the algorithm heuristic once, and pack the
// result into *desc_ptr for bangMatmul to reuse on every call.
// NOTE(review): unlike the create paths of sibling backends, dtype is not
// validated here — bangMatmul rejects unsupported dtypes at call time.
infiniopStatus_t bangCreateMatmulDescriptor(
infiniopBangHandle_t handle, infiniopMatmulBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, false);
if (status != INFINIOP_STATUS_SUCCESS) {
return status;
}
// NOTE(review): cnnl statuses below are not checked; failures here only
// surface later when the matmul actually runs.
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
cnnlCreateTensorDescriptor(&aDesc);
cnnlCreateTensorDescriptor(&bDesc);
cnnlCreateTensorDescriptor(&cDesc);
setMatrixTensorEx(aDesc, info.a_matrix, a_desc->dtype);
setMatrixTensorEx(bDesc, info.b_matrix, b_desc->dtype);
setMatrixTensorEx(cDesc, info.c_matrix, c_desc->dtype);
cnnlMatMulDescriptor_t opDesc;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
cnnlMatMulDescCreate(&opDesc);
cnnlMatMulAlgoCreate(&algo);
cnnlCreateMatMulHeuristicResult(&algoResult);
// Honor the explicit strides set on the tensor descriptors.
int32_t use_stride = true;
cnnlSetMatMulDescAttr(opDesc, CNNL_MATMUL_USE_STRIDE, &use_stride,
sizeof(int32_t));
// Borrow a pooled handle to run the heuristic; request one result.
int count = 0;
use_cnnl(handle->cnnl_handle_pool, [&](cnnlHandle_t _handle) {
cnnlGetBatchMatMulAlgoHeuristic(_handle, opDesc, aDesc, bDesc, cDesc,
NULL, 1, &algoResult, &count);
});
// Extract the chosen algorithm and its workspace requirement.
size_t workspace_size;
cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size);
*desc_ptr = new InfiniopMatmulBangDescriptor{handle->device,
handle->device_id,
info,
c_desc->dtype,
handle->cnnl_handle_pool,
aDesc,
bDesc,
cDesc,
opDesc,
algo,
algoResult,
workspace_size};
return INFINIOP_STATUS_SUCCESS;
}
// Report how many workspace bytes bangMatmul needs for this descriptor.
// The value was computed once at creation time from the cnnl heuristic
// result, so this is a plain field read.
infiniopStatus_t bangGetMatmulWorkspaceSize(infiniopMatmulBangDescriptor_t desc,
                                            size_t *size) {
    auto const required = desc->workspace_size;
    *size = required;
    return INFINIOP_STATUS_SUCCESS;
}
// Free everything owned by the descriptor: drop the pool reference, then
// destroy each cnnl object in the reverse area they were created in, and
// finally free the descriptor struct itself.
infiniopStatus_t
bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc) {
// Resetting the shared_ptr releases this descriptor's hold on the pool.
desc->cnnl_handle_pool = nullptr;
cnnlDestroyTensorDescriptor(desc->aDesc);
cnnlDestroyTensorDescriptor(desc->bDesc);
cnnlDestroyTensorDescriptor(desc->cDesc);
cnnlMatMulDescDestroy(desc->opDesc);
cnnlMatMulAlgoDestroy(desc->algo);
cnnlDestroyMatMulHeuristicResult(desc->algoResult);
delete desc;
return INFINIOP_STATUS_SUCCESS;
}
// Enqueue C = alpha * A @ B + beta * C through cnnl on the given queue.
// NOTE(review): no queue sync here and the cnnl status is ignored —
// bangMatmul is the caller that performs cnrtQueueSync afterwards.
void bangMatmulCnnl(infiniopMatmulBangDescriptor_t desc, void *workspace, void *c,
float beta, void const *a, void const *b, float alpha,
void *stream) {
auto info = desc->info;
// MatmulInfo may have transposed the whole problem (is_transed); in that
// case A and B swap roles so they match the prepared cnnl descriptors.
if (info.is_transed) {
std::swap(a, b);
}
// Borrow a cnnl handle bound to this queue and launch the batched GEMM.
use_cnnl(desc->cnnl_handle_pool, (cnrtQueue_t)stream, [&](cnnlHandle_t handle) {
cnnlBatchMatMulBCast_v2(handle, desc->opDesc, desc->algo, &alpha,
desc->aDesc, a, desc->bDesc, b, &beta,
desc->cDesc, c, workspace,
desc->workspace_size);
});
}
// Run a (batched) matmul on the MLU and block until it completes.
// Only F16 and F32 are supported; any other dtype is rejected up front.
// `workspace_size` is accepted for interface symmetry but not validated
// here (the descriptor already carries the required size).
infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
                            void *workspace, size_t workspace_size, void *c,
                            void const *a, void const *b, float alpha,
                            float beta, void *stream) {
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_F32:
        bangMatmulCnnl(desc, workspace, c, beta, a, b, alpha, stream);
        // Synchronize the queue so results are ready when we return.
        cnrtQueueSync((cnrtQueue_t)stream);
        return INFINIOP_STATUS_SUCCESS;
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
#ifndef __CNNL_MATMUL_H__
#define __CNNL_MATMUL_H__
#include "../../../devices/bang/common_bang.h"
#include "../blas.h"
#include "cnnl_extra.h"
// State for one cambricon (cnnl) matmul operator instance.
// Aggregate-initialized in bangCreateMatmulDescriptor — field order is
// part of the contract; released in bangDestroyMatmulDescriptor.
struct InfiniopMatmulBangDescriptor {
infiniDevice_t device;
int device_id;
// Shape/stride info for the A, B, C matrices (may be globally transposed).
MatmulInfo info;
// Element dtype of C; bangMatmul only accepts F16/F32.
infiniDtype_t dtype;
// Shared handle pool used to borrow a cnnlHandle_t per call.
std::shared_ptr<Pool<cnnlHandle_t>> cnnl_handle_pool;
cnnlTensorDescriptor_t aDesc;
cnnlTensorDescriptor_t bDesc;
cnnlTensorDescriptor_t cDesc;
// cnnl op config plus the algorithm/heuristic chosen at create time.
cnnlMatMulDescriptor_t opDesc;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
// Workspace bytes required by the chosen algorithm.
size_t workspace_size;
};
// Populate `desc` with the array layout described by `matrix`.
// Rank 3 maps to {batch, rows, cols} with strides {stride, row, col};
// rank 2 maps to {rows, cols}. Any other rank leaves `desc` unmodified.
// `trans` is currently unused; kept for interface compatibility.
inline void setMatrixTensorEx(cnnlTensorDescriptor_t desc,
                              const BlasMatrix &matrix, infiniDtype_t dtype,
                              bool trans = false) {
    switch (matrix.ndim) {
    case 3: {
        std::vector<int> dims = {static_cast<int>(matrix.batch),
                                 static_cast<int>(matrix.rows),
                                 static_cast<int>(matrix.cols)};
        std::vector<int> steps = {static_cast<int>(matrix.stride),
                                  static_cast<int>(matrix.row_stride),
                                  static_cast<int>(matrix.col_stride)};
        cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY,
                                  cnnlDataTypeConvert(dtype), dims.size(),
                                  dims.data(), steps.data());
        break;
    }
    case 2: {
        std::vector<int> dims = {static_cast<int>(matrix.rows),
                                 static_cast<int>(matrix.cols)};
        std::vector<int> steps = {static_cast<int>(matrix.row_stride),
                                  static_cast<int>(matrix.col_stride)};
        cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY,
                                  cnnlDataTypeConvert(dtype), dims.size(),
                                  dims.data(), steps.data());
        break;
    }
    default:
        break;
    }
}
#endif // __CNNL_MATMUL_H__
#ifndef __CNNL_MATMUL_API_H__
#define __CNNL_MATMUL_API_H__
#include "../../../devices/bang/bang_handle.h"
#include "infiniop/operator.h"
// Opaque handle; definition lives in the cnnl matmul header.
struct InfiniopMatmulBangDescriptor;
typedef struct InfiniopMatmulBangDescriptor *infiniopMatmulBangDescriptor_t;
// Build a descriptor for C = alpha * A @ B + beta * C on the MLU.
infiniopStatus_t bangCreateMatmulDescriptor(
infiniopBangHandle_t handle, infiniopMatmulBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
// Report the workspace size computed at descriptor-creation time.
infiniopStatus_t bangGetMatmulWorkspaceSize(infiniopMatmulBangDescriptor_t desc,
size_t *size);
// Run the matmul synchronously (F16/F32 only).
infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
void *workspace, size_t workspace_size, void *c,
void const *a, void const *b, float alpha,
float beta, void *stream);
// Free the descriptor and every cnnl object it owns.
infiniopStatus_t
bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc);
#endif
#ifndef __BLAS_H__ #ifndef __BLAS_H__
#define __BLAS_H__ #define __BLAS_H__
#include "../utils.h"
#include "infiniop/operator.h" #include "infiniop/operator.h"
#include <algorithm> #include <algorithm>
#include <stdint.h>
typedef struct BlasMatrix { namespace matmul {
struct BlasMatrix {
size_t ndim; size_t ndim;
size_t batch; size_t batch;
ptrdiff_t stride; ptrdiff_t stride;
...@@ -15,31 +14,31 @@ typedef struct BlasMatrix { ...@@ -15,31 +14,31 @@ typedef struct BlasMatrix {
ptrdiff_t row_stride; ptrdiff_t row_stride;
ptrdiff_t col_stride; ptrdiff_t col_stride;
BlasMatrix() {} BlasMatrix() = default;
BlasMatrix(infiniopTensorDescriptor_t layout, infiniopStatus_t *status) { BlasMatrix(infiniopTensorDescriptor_t layout, infiniopStatus_t *status) {
if (layout->ndim == 2) { if (layout->ndim == 2) {
this->ndim = 2; ndim = 2;
this->batch = 1; batch = 1;
this->stride = 0; stride = 0;
this->rows = layout->shape[0]; rows = layout->shape[0];
this->cols = layout->shape[1]; cols = layout->shape[1];
this->row_stride = layout->strides[0]; row_stride = layout->strides[0];
this->col_stride = layout->strides[1]; col_stride = layout->strides[1];
} else if (layout->ndim == 3) { } else if (layout->ndim == 3) {
this->ndim = 3; ndim = 3;
this->batch = layout->shape[0]; batch = layout->shape[0];
this->stride = this->batch == 1 ? 0 : layout->strides[0]; stride = batch == 1 ? 0 : layout->strides[0];
this->rows = layout->shape[1]; rows = layout->shape[1];
this->cols = layout->shape[2]; cols = layout->shape[2];
this->row_stride = layout->strides[1]; row_stride = layout->strides[1];
this->col_stride = layout->strides[2]; col_stride = layout->strides[2];
} else { } else {
*status = INFINIOP_STATUS_BAD_TENSOR_SHAPE; *status = INFINIOP_STATUS_BAD_TENSOR_SHAPE;
return; return;
} }
if (this->row_stride != 1 && this->col_stride != 1) { if (row_stride != 1 && col_stride != 1) {
*status = INFINIOP_STATUS_BAD_TENSOR_STRIDES; *status = INFINIOP_STATUS_BAD_TENSOR_STRIDES;
return; return;
} }
...@@ -48,7 +47,7 @@ typedef struct BlasMatrix { ...@@ -48,7 +47,7 @@ typedef struct BlasMatrix {
} }
bool match_batch(size_t _batch) const { bool match_batch(size_t _batch) const {
return this->batch == _batch || this->batch == 1; return batch == _batch || batch == 1;
} }
void transpose() { void transpose() {
...@@ -57,13 +56,14 @@ typedef struct BlasMatrix { ...@@ -57,13 +56,14 @@ typedef struct BlasMatrix {
} }
ptrdiff_t ld() const { ptrdiff_t ld() const {
if (this->row_stride == 1) { return row_stride == 1 ? col_stride : row_stride;
return this->col_stride;
} else {
return this->row_stride;
}
} }
} BlasMatrix; };
enum class MatrixLayout : uint8_t {
COL_MAJOR,
ROW_MAJOR,
};
struct MatmulInfo { struct MatmulInfo {
BlasMatrix a_matrix; BlasMatrix a_matrix;
...@@ -74,7 +74,11 @@ struct MatmulInfo { ...@@ -74,7 +74,11 @@ struct MatmulInfo {
bool is_transed = false; bool is_transed = false;
MatmulInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopStatus_t *status, bool col_major = true) { MatmulInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopStatus_t *status,
MatrixLayout layout) {
a_matrix = BlasMatrix(a_desc, status); a_matrix = BlasMatrix(a_desc, status);
if (*status != INFINIOP_STATUS_SUCCESS) { if (*status != INFINIOP_STATUS_SUCCESS) {
return; return;
...@@ -99,7 +103,8 @@ struct MatmulInfo { ...@@ -99,7 +103,8 @@ struct MatmulInfo {
return; return;
} }
if ((col_major && c_matrix.col_stride == 1) || (!col_major && c_matrix.row_stride == 1)) { if ((layout == MatrixLayout::COL_MAJOR && c_matrix.col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix.row_stride == 1)) {
c_matrix.transpose(); c_matrix.transpose();
b_matrix.transpose(); b_matrix.transpose();
a_matrix.transpose(); a_matrix.transpose();
...@@ -112,5 +117,6 @@ struct MatmulInfo { ...@@ -112,5 +117,6 @@ struct MatmulInfo {
k = a_matrix.cols; k = a_matrix.cols;
} }
}; };
} // namespace matmul
#endif // __BLAS_H__ #endif // __BLAS_H__
#include "./matmul_cpu.h" #include "./matmul_cpu.h"
#include "../../../devices/cpu/common_cpu.h" #include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h" #include <iostream>
#include <cmath>
infiniopStatus_t cpuCreateMatmulDescriptor( namespace matmul::cpu {
infiniopCpuHandle_t handle, infiniopMatmulCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, Descriptor::~Descriptor() = default;
infiniopStatus_t Descriptor::create(
infiniopCpuHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
infiniDtype_t dtype = c_desc->dtype; infiniDtype_t dtype = c_desc->dtype;
...@@ -14,32 +19,26 @@ infiniopStatus_t cpuCreateMatmulDescriptor( ...@@ -14,32 +19,26 @@ infiniopStatus_t cpuCreateMatmulDescriptor(
} }
infiniopStatus_t status; infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status); auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) { if (status != INFINIOP_STATUS_SUCCESS) {
return status; return status;
} }
*desc_ptr = new MatmulCpuDescriptor{INFINI_DEVICE_CPU, dtype, info}; *desc_ptr = new Descriptor(
dtype, info, 0,
return INFINIOP_STATUS_SUCCESS; nullptr,
} handle->device, handle->device_id);
infiniopStatus_t cpuGetMatmulWorkspaceSize(infiniopMatmulCpuDescriptor_t desc,
size_t *size) {
*size = 0;
return INFINIOP_STATUS_SUCCESS;
}
infiniopStatus_t
cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc) {
delete desc;
return INFINIOP_STATUS_SUCCESS; return INFINIOP_STATUS_SUCCESS;
} }
template <typename Tdata> template <typename Tdata>
infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c, void calculate(
float beta, void const *a, void const *b, Descriptor const *desc,
float alpha) { void *c,
float beta,
void const *a,
void const *b,
float alpha) {
auto info = desc->info; auto info = desc->info;
if (info.is_transed) { if (info.is_transed) {
...@@ -72,17 +71,30 @@ infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c, ...@@ -72,17 +71,30 @@ infiniopStatus_t cpuCalculateMatmul(infiniopMatmulCpuDescriptor_t desc, void *c,
} }
} }
} }
return INFINIOP_STATUS_SUCCESS;
} }
infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace, infiniopStatus_t Descriptor::calculate(
size_t workspace_size, void *c, void const *a, void *workspace,
void const *b, float alpha, float beta) { size_t workspace_size,
if (desc->dtype == INFINI_DTYPE_F16) { void *c,
return cpuCalculateMatmul<uint16_t>(desc, c, beta, a, b, alpha); float beta,
} void const *a,
if (desc->dtype == INFINI_DTYPE_F32) { void const *b,
return cpuCalculateMatmul<float>(desc, c, beta, a, b, alpha); float alpha,
void *stream) const {
switch (dtype) {
case INFINI_DTYPE_F16:
cpu::calculate<uint16_t>(this, c, beta, a, b, alpha);
return INFINIOP_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
cpu::calculate<float>(this, c, beta, a, b, alpha);
return INFINIOP_STATUS_SUCCESS;
default:
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
} }
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
} }
} // namespace matmul::cpu
#ifndef __INFINIOP_MATMUL_CPU_H__ #ifndef __MATMUL_CPU_H__
#define __INFINIOP_MATMUL_CPU_H__ #define __MATMUL_CPU_H__
#include "../blas.h" #include "../../../devices/cpu/cpu_handle.h"
#include "./matmul_cpu_api.h" #include "../matmul.h"
typedef struct MatmulCpuDescriptor { DESCRIPTOR(cpu, infiniopCpuHandle_t)
infiniDevice_t device;
infiniDtype_t dtype;
MatmulInfo info;
} MatmulCpuDescriptor;
#endif // __INFINIOP_MATMUL_CPU_H__ #endif // __MATMUL_CPU_H__
#ifndef __INFINIOP_MATMUL_CPU_API_H__
#define __INFINIOP_MATMUL_CPU_API_H__
#include "../../../devices/cpu/cpu_handle.h"
#include "infiniop/operator.h"
struct MatmulCpuDescriptor;
typedef struct MatmulCpuDescriptor *infiniopMatmulCpuDescriptor_t;
infiniopStatus_t cpuCreateMatmulDescriptor(
infiniopCpuHandle_t handle, infiniopMatmulCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
infiniopStatus_t cpuGetMatmulWorkspaceSize(infiniopMatmulCpuDescriptor_t desc,
size_t *size);
infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace,
size_t workspace_size, void *c, void const *a,
void const *b, float alpha, float beta);
infiniopStatus_t cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc);
#endif // __INFINIOP_MATMUL_CPU_API_H__
#include "../../utils.h" #include "../../utils.h"
#include "./matmul_cuda.cuh" #include "matmul_cuda.cuh"
infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle, namespace matmul::cuda {
infiniopMatmulCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, struct Descriptor::Opaque {
infiniopTensorDescriptor_t a_desc, std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool;
infiniopTensorDescriptor_t b_desc) { };
Descriptor::~Descriptor() {
delete _opaque;
}
infiniopStatus_t Descriptor::create(
infiniopCudaHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
infiniDtype_t dtype = c_desc->dtype; infiniDtype_t dtype = c_desc->dtype;
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
...@@ -13,27 +24,103 @@ infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle, ...@@ -13,27 +24,103 @@ infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle,
} }
infiniopStatus_t status; infiniopStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status); auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINIOP_STATUS_SUCCESS) { if (status != INFINIOP_STATUS_SUCCESS) {
return status; return status;
} }
*desc_ptr = new InfiniopMatmulCudaDescriptor{ *desc_ptr = new Descriptor(
handle->device, dtype, info, 0,
dtype, new Opaque{handle->cublas_handle_pool},
handle->device_id, handle->device, handle->device_id);
info,
handle->cublas_handle_pool};
return INFINIOP_STATUS_SUCCESS; return INFINIOP_STATUS_SUCCESS;
} }
infiniopStatus_t cudaGetMatmulWorkspaceSize(infiniopMatmulCudaDescriptor_t desc, size_t *size) { template <typename Tdata>
*size = 0; infiniopStatus_t calculate(
MatmulInfo const &info,
std::shared_ptr<Pool<cublasHandle_t>> &cublas_handle_pool,
void *c,
float beta,
void const *a,
void const *b,
float alpha,
cudaStream_t stream) {
if (info.is_transed) {
std::swap(a, b);
}
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;
if constexpr (std::is_same<Tdata, half>::value) {
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
} else {
a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
compute_type = CUBLAS_COMPUTE_32F;
#else
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
}
auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
use_cublas(cublas_handle_pool,
stream,
[&](cublasHandle_t handle) {
cublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
});
return INFINIOP_STATUS_SUCCESS; return INFINIOP_STATUS_SUCCESS;
} }
infiniopStatus_t cudaDestroyMatmulDescriptor(infiniopMatmulCudaDescriptor_t desc) { infiniopStatus_t Descriptor::calculate(
desc->cublas_handle_pool = nullptr; void *workspace,
delete desc; size_t workspace_size,
return INFINIOP_STATUS_SUCCESS; void *c,
float beta,
void const *a,
void const *b,
float alpha,
void *stream) const {
switch (dtype) {
case INFINI_DTYPE_F16:
cuda::calculate<uint16_t>(info, _opaque->cublas_handle_pool, c, beta, a, b, alpha, (cudaStream_t)stream);
return INFINIOP_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
cuda::calculate<float>(info, _opaque->cublas_handle_pool, c, beta, a, b, alpha, (cudaStream_t)stream);
return INFINIOP_STATUS_SUCCESS;
default:
return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
}
} }
} // namespace matmul::cuda
#ifndef __INFINIOP_MATMUL_CUDA_H__ #ifndef __MATMUL_CUDA_CUH__
#define __INFINIOP_MATMUL_CUDA_H__ #define __MATMUL_CUDA_CUH__
#include "../../../devices/cuda/common_cuda.cuh" #include "../../../devices/cuda/cuda_handle.h"
#include "../blas.h" #include "../matmul.h"
#include "matmul_cuda_api.h"
#include <memory>
typedef struct InfiniopMatmulCudaDescriptor { DESCRIPTOR(cuda, infiniopCudaHandle_t)
infiniDevice_t device;
infiniDtype_t dtype;
int device_id;
MatmulInfo info;
std::shared_ptr<Pool<cublasHandle_t>> cublas_handle_pool;
} InfiniopMatmulCudaDescriptor;
#endif // __INFINIOP_MATMUL_CUDA_H__ #endif // __MATMUL_CUDA_CUH__
#ifndef __INFINIOP_MATMUL_CUDA_API_H__
#define __INFINIOP_MATMUL_CUDA_API_H__
#include "../../../devices/cuda/cuda_handle.h"
#include "infiniop/operator.h"
// Opaque descriptor for a CUDA (cuBLAS) matmul: holds the problem geometry
// and a shared cuBLAS handle pool (see the struct definition in matmul_cuda.cuh).
struct InfiniopMatmulCudaDescriptor;
typedef struct InfiniopMatmulCudaDescriptor *infiniopMatmulCudaDescriptor_t;
// Validates the c/a/b tensor descriptors (only F16 and F32 are accepted)
// and allocates a descriptor capturing the matmul geometry and device info.
infiniopStatus_t cudaCreateMatmulDescriptor(infiniopCudaHandle_t handle,
                                            infiniopMatmulCudaDescriptor_t *desc_ptr,
                                            infiniopTensorDescriptor_t c_desc,
                                            infiniopTensorDescriptor_t a_desc,
                                            infiniopTensorDescriptor_t b_desc);
// Reports the workspace byte count required by cudaMatmul (currently always 0
// for this cuBLAS-backed implementation).
infiniopStatus_t cudaGetMatmulWorkspaceSize(infiniopMatmulCudaDescriptor_t desc, size_t *size);
// Computes c = alpha * a @ b + beta * c asynchronously on the given CUDA stream.
infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
                            void *workspace,
                            size_t workspace_size,
                            void *c,
                            void const *a,
                            void const *b,
                            float alpha,
                            float beta,
                            void *stream);
// Releases the descriptor and drops its reference to the cuBLAS handle pool.
infiniopStatus_t cudaDestroyMatmulDescriptor(infiniopMatmulCudaDescriptor_t desc);
#endif // __INFINIOP_MATMUL_CUDA_API_H__
#include "../../utils.h"
#include "./matmul_cuda.cuh"
// Runs a strided-batched GEMM through cuBLAS for element type Tdata
// (half or float): c = alpha * a @ b + beta * c, asynchronously on `stream`.
// Geometry (m/n/k, leading dimensions, batch strides) comes from desc->info,
// prepared at descriptor-creation time.
template <typename Tdata>
infiniopStatus_t cudaMatmulCublas(infiniopMatmulCudaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) {
    auto info = desc->info;

    // If create-time canonicalization transposed the problem, the operand
    // roles of a and b are exchanged here (pointer swap only).
    if (info.is_transed) {
        std::swap(a, b);
    }

    // Choose element and accumulation types for cublasGemmStridedBatchedEx.
    cudaDataType a_type, b_type, c_type;
    cublasComputeType_t compute_type;
    if constexpr (std::is_same<Tdata, half>::value) {
        a_type = b_type = c_type = CUDA_R_16F;
        compute_type = CUBLAS_COMPUTE_32F; // fp16 data, fp32 accumulation
    } else {
        a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_CUDA_API
        compute_type = CUBLAS_COMPUTE_32F; // Sugon toolchain: plain fp32 path
#else
        compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
    }

    // Derive the transpose flags from each matrix's stride layout:
    // unit row stride means no op-side transpose is needed.
    auto const trans_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
    auto const trans_b = info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;

    auto const launch_gemm = [&](cublasHandle_t handle) {
        cublasGemmStridedBatchedEx(
            handle,
            trans_a, trans_b,
            static_cast<int>(info.m), static_cast<int>(info.n), static_cast<int>(info.k),
            &alpha,
            a, a_type, static_cast<int>(info.a_matrix.ld()), info.a_matrix.stride,
            b, b_type, static_cast<int>(info.b_matrix.ld()), info.b_matrix.stride,
            &beta,
            c, c_type, static_cast<int>(info.c_matrix.ld()), info.c_matrix.stride,
            static_cast<int>(info.batch),
            compute_type,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    };
    // Borrow a cuBLAS handle from the pool, bind it to the stream, and launch.
    use_cublas(desc->cublas_handle_pool, desc->device_id, (cudaStream_t)stream, launch_gemm);
    return INFINIOP_STATUS_SUCCESS;
}
// Public entry point: dispatches to the cuBLAS helper instantiated for the
// descriptor's element type. `workspace`/`workspace_size` are accepted for
// API uniformity but unused by this backend (required size is 0).
infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
                            void *workspace,
                            size_t workspace_size,
                            void *c,
                            void const *a,
                            void const *b,
                            float alpha,
                            float beta,
                            void *stream) {
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
        return cudaMatmulCublas<half>(desc, c, beta, a, b, alpha, stream);
    case INFINI_DTYPE_F32:
        return cudaMatmulCublas<float>(desc, c, beta, a, b, alpha, stream);
    default:
        // Creation already rejects other dtypes, but keep the guard here too.
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
#ifndef __MATMUL_H__
#define __MATMUL_H__
#include "blas.h"
#include "infiniop/operator.h"
// DESCRIPTOR(NAMESPACE, HANDLE) generates the per-backend matmul descriptor
// class `matmul::NAMESPACE::Descriptor`. Each backend supplies:
//   - a nested `Opaque` definition holding backend-private state (pimpl,
//     owned via the raw `_opaque` pointer and released in ~Descriptor()),
//   - the static `create()` factory taking the backend's HANDLE type,
//   - `calculate()`, which computes c = alpha * a @ b + beta * c on `stream`
//     using the geometry captured in `info` at creation time.
// The constructor is private; instances are built only through `create()`.
// Comments must stay outside the macro: a `//` line would break the `\`
// continuation chain.
#define DESCRIPTOR(NAMESPACE, HANDLE) \
 \
 namespace matmul::NAMESPACE { \
 class Descriptor final : public InfiniopDescriptor { \
 struct Opaque; \
 Opaque *_opaque; \
 \
 Descriptor( \
 infiniDtype_t dtype_, \
 MatmulInfo info_, \
 size_t workspace_size_, \
 Opaque *opaque, \
 infiniDevice_t device_type, \
 int device_id) \
 : InfiniopDescriptor{device_type, device_id}, \
 _opaque(opaque), \
 dtype(dtype_), \
 info(info_), \
 workspace_size(workspace_size_) {} \
 \
 public: \
 infiniDtype_t dtype; \
 MatmulInfo info; \
 size_t workspace_size; \
 \
 ~Descriptor(); \
 \
 static infiniopStatus_t create( \
 HANDLE handle, \
 Descriptor **desc_ptr, \
 infiniopTensorDescriptor_t c_desc, \
 infiniopTensorDescriptor_t a_desc, \
 infiniopTensorDescriptor_t b_desc); \
 \
 infiniopStatus_t calculate( \
 void *workspace, \
 size_t workspace_size, \
 void *c, \
 float beta, \
 void const *a, \
 void const *b, \
 float alpha, \
 void *stream) const; \
 }; \
 }
#endif // __MATMUL_H__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment