Commit c9dbc0ff authored by zhangyue's avatar zhangyue
Browse files

issue/25: restructure matmul

parent 8583cae7
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ascend/ascend_handle.h" #include "ascend/ascend_handle.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "./kunlun/kunlun_handle.h" #include "kunlun/kunlun_handle.h"
#endif #endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
namespace xdnn = baidu::xpu::api; namespace xdnn = baidu::xpu::api;
typedef xdnn::Context *xdnnHandle_t; typedef xdnn::Context *xdnnHandle_t;
typedef XPUStream KunlunStream_t;
#define CHECK_KUNLUN(call) \ #define CHECK_KUNLUN(call) \
{ \ { \
...@@ -18,7 +19,7 @@ typedef xdnn::Context *xdnnHandle_t; ...@@ -18,7 +19,7 @@ typedef xdnn::Context *xdnnHandle_t;
if (XPU_SUCCESS != err) { \ if (XPU_SUCCESS != err) { \
fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \ fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \
__LINE__, xpu_strerror(err)); \ __LINE__, xpu_strerror(err)); \
return INFINIOP_STATUS_INTERNAL_ERROR; \ return INFINI_STATUS_INTERNAL_ERROR; \
} \ } \
} }
...@@ -29,17 +30,14 @@ struct InfiniopKunlunHandle { ...@@ -29,17 +30,14 @@ struct InfiniopKunlunHandle {
}; };
template <typename T> template <typename T>
infiniopStatus_t use_xdnn(std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool, void use_xdnn(std::shared_ptr<Pool<xdnnHandle_t>> &pool, KunlunStream_t stream, T const &f) {
XPUStream stream, auto handle = pool->pop();
T const &f) {
auto handle = xdnn_handle_pool->pop();
if (!handle) { if (!handle) {
*handle = xdnn::create_context(); *handle = xdnn::create_context();
} }
(*handle)->set_stream(stream); (*handle)->set_stream(stream);
auto ret = f(*handle); f(*handle);
xdnn_handle_pool->push(std::move(*handle)); pool->push(std::move(*handle));
return ret;
} }
#endif #endif //__INFINIOP_COMMON_KUNLUN_H__
#include "common_kunlun.h" #include "common_kunlun.h"
infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) { infiniStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) {
int device_id; int device_id;
CHECK_KUNLUN(xpu_current_device(&device_id)) CHECK_KUNLUN(xpu_current_device(&device_id));
auto pool = std::make_shared<Pool<xdnnHandle_t>>(); auto pool = std::make_shared<Pool<xdnnHandle_t>>();
xdnnHandle_t handle = xdnn::create_context(); xdnnHandle_t handle = xdnn::create_context();
pool->push(std::move(handle)); pool->push(std::move(handle));
...@@ -14,11 +13,12 @@ infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) { ...@@ -14,11 +13,12 @@ infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) {
std::move(pool), std::move(pool),
}; };
return INFINIOP_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniopStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle) { infiniStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle_ptr) {
handle->xdnn_handle_pool = nullptr; handle_ptr->xdnn_handle_pool = nullptr;
delete handle; delete handle_ptr;
return INFINIOP_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
} }
...@@ -7,7 +7,7 @@ struct InfiniopKunlunHandle; ...@@ -7,7 +7,7 @@ struct InfiniopKunlunHandle;
typedef struct InfiniopKunlunHandle *infiniopKunlunHandle_t; typedef struct InfiniopKunlunHandle *infiniopKunlunHandle_t;
infiniopStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr); infiniStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr);
infiniopStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle); infiniStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle);
#endif #endif // __INFINIOP_KUNLUN_HANDLE_H__
#include "matmul_kunlun.h"
#include "../../../devices/kunlun/common_kunlun.h"
#include "../../utils.h"
namespace matmul::kunlun {
// Backend-private state of the Kunlun matmul Descriptor.
struct Descriptor::Opaque {
    // Pool of xdnn contexts shared with (copied from) the device handle.
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};
// Frees the Opaque state; dropping it releases this descriptor's
// reference to the shared xdnn context pool.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Create a Kunlun matmul descriptor for C = alpha * A @ B + beta * C.
//
// Validates the compute dtype (this backend supports only F16/F32),
// derives matrix-layout metadata via MatmulInfo, and stashes a reference
// to the device handle's xdnn context pool in the Opaque state.
//
// @param handle_   generic infiniop handle; must actually be a Kunlun handle.
// @param desc_ptr  out: receives the newly allocated descriptor.
// @param c_desc    output tensor descriptor (also fixes the compute dtype).
// @param a_desc    left operand tensor descriptor.
// @param b_desc    right operand tensor descriptor.
// @return INFINI_STATUS_SUCCESS, INFINI_STATUS_BAD_TENSOR_DTYPE, or the
//         status MatmulInfo reported for a shape/stride validation failure.
infiniStatus_t Descriptor::create(infiniopHandle_t handle_,
                                  Descriptor **desc_ptr,
                                  infiniopTensorDescriptor_t c_desc,
                                  infiniopTensorDescriptor_t a_desc,
                                  infiniopTensorDescriptor_t b_desc) {
    auto handle = reinterpret_cast<infiniopKunlunHandle_t>(handle_);
    auto dtype = c_desc->dtype;
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Initialize explicitly: the value is read below, and reading an
    // uninitialized enum would be UB if MatmulInfo ever returned without
    // writing through the out-pointer.
    infiniStatus_t status = INFINI_STATUS_SUCCESS;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    *desc_ptr = new Descriptor(
        dtype, info, 0, // workspace size 0: this backend needs no scratch buffer
        new Opaque{handle->xdnn_handle_pool},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Run the (batched) matmul on a Kunlun XPU by issuing one xdnn::fc_fusion
// per batch element on a pooled xdnn context bound to `stream`.
//
// @tparam Tdata  element type of A, B and C (float16 or float).
// @param info    precomputed matrix layout/stride metadata.
// @param xdnn_handle_pool  pool of xdnn contexts to borrow a handle from.
// @param dtype   runtime dtype tag, used only to compute the element size.
// @param c       output buffer; read when beta != 0, then overwritten.
// @param a, b    input buffers.
// @param alpha, beta  GEMM scaling factors.
// @param stream  Kunlun stream the work is enqueued on.
template <class Tdata>
void calculate(
    const MatmulInfo &info,
    std::shared_ptr<Pool<xdnnHandle_t>> &xdnn_handle_pool,
    infiniDtype_t dtype,
    void *c,
    float beta,
    void const *a,
    void const *b,
    float alpha,
    KunlunStream_t stream) {
    // MatmulInfo may have transposed the whole problem; then A and B swap roles.
    if (info.is_transed) {
        std::swap(a, b);
    }
    // col_stride == 1 means row-major storage; anything else is treated as
    // a transposed operand for fc_fusion.
    auto transA = info.a_matrix.col_stride == 1 ? false : true;
    auto transB = info.b_matrix.col_stride == 1 ? false : true;
    auto unit = infiniSizeof(dtype); // bytes per element, for stride arithmetic
    use_xdnn(xdnn_handle_pool,
             (KunlunStream_t)stream,
             [&](xdnnHandle_t handle) {
                 for (size_t i = 0; i < info.batch; i++) {
                     // NOTE(review): fc_fusion's status code is discarded here;
                     // the previous implementation wrapped this call in
                     // CHECK_KUNLUN. Confirm dropping the check is intentional.
                     xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                         handle,
                         // Per-batch base pointers: batch stride is in elements,
                         // so scale by the element size.
                         (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
                         (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
                         (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
                         info.m,
                         info.n,
                         info.k,
                         transA,
                         transB,
                         // The three nullptrs are optional input-side arguments
                         // (presumably max/scale pointers — verify against the
                         // xdnn fc_fusion API).
                         nullptr,
                         nullptr,
                         nullptr,
                         info.a_matrix.ld(),
                         info.b_matrix.ld(),
                         info.c_matrix.ld(),
                         alpha,
                         beta,
                         nullptr,
                         xdnn::Activation_t::LINEAR, // plain GEMM, no activation
                         nullptr);
                 }
             });
}
// Dispatch the batched matmul to the dtype-specialized kernel.
//
// @param workspace       unused — this backend requires no scratch memory.
// @param workspace_size  unused (see above); parameter name typo fixed
//                        (was `worksapce_size`).
// @param c               output buffer.
// @param beta, alpha     GEMM scaling factors.
// @param a, b            input buffers.
// @param stream          Kunlun stream (opaque void* from the public API).
// @return INFINI_STATUS_SUCCESS, or INFINI_STATUS_BAD_TENSOR_DTYPE for a
//         dtype this backend cannot handle.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    (void)workspace;
    (void)workspace_size;
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        kunlun::calculate<float16>(_info, _opaque->xdnn_handle_pool, _dtype, c, beta, a, b, alpha, (KunlunStream_t)stream);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F32:
        kunlun::calculate<float>(_info, _opaque->xdnn_handle_pool, _dtype, c, beta, a, b, alpha, (KunlunStream_t)stream);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace matmul::kunlun
#ifndef __MATMUL_KUNLUN_H__
#define __MATMUL_KUNLUN_H__

#include "../matmul.h"

// Instantiate the shared matmul Descriptor declaration for the Kunlun
// backend (the definitions live in matmul_kunlun.cc under matmul::kunlun).
DESCRIPTOR(kunlun)

#endif // __MATMUL_KUNLUN_H__
#include "matmul_xdnn.h"
// Shared implementation behind kunlunMatmul: run the (batched) GEMM with
// xdnn::fc_fusion on a pooled xdnn context bound to `stream`.
//
// @tparam T      element type of A, B and C (float16 or float).
// @param desc    descriptor holding layout info, dtype and the context pool.
// @param c       output buffer; read when beta != 0, then overwritten.
// @param beta, alpha  GEMM scaling factors.
// @param a, b    input buffers.
// @param stream  Kunlun stream (opaque void* from the public API).
// @return the status produced by the lambda (INFINIOP_STATUS_SUCCESS, or
//         INFINIOP_STATUS_INTERNAL_ERROR from CHECK_KUNLUN on failure).
template <typename T>
infiniopStatus_t matmulKunlunCommon(infiniopMatmulKunlunDescriptor_t desc,
                                    void *c,
                                    float beta,
                                    void const *a,
                                    void const *b,
                                    float alpha,
                                    void *stream) {
    auto info = desc->info;
    // MatmulInfo may have transposed the whole problem; then A and B swap roles.
    if (info.is_transed) {
        std::swap(a, b);
    }
    // col_stride == 1 means row-major storage; anything else is treated as
    // a transposed operand for fc_fusion.
    auto transA = info.a_matrix.col_stride == 1 ? false : true;
    auto transB = info.b_matrix.col_stride == 1 ? false : true;
    auto ret = use_xdnn(desc->xdnn_handle_pool,
                        (XPUStream)stream,
                        [&](xdnnHandle_t handle) {
                            for (size_t i = 0; i < info.batch; i++) {
                                // CHECK_KUNLUN returns INTERNAL_ERROR from the
                                // lambda if fc_fusion reports a failure.
                                CHECK_KUNLUN((
                                    xdnn::fc_fusion<T, T, T, int16_t>(
                                        handle,
                                        // Per-batch base pointers: batch stride is
                                        // in elements, so scale by element size.
                                        (T *)((char *)a + i * info.a_matrix.stride * infiniSizeof(desc->dtype)),
                                        (T *)((char *)b + i * info.b_matrix.stride * infiniSizeof(desc->dtype)),
                                        (T *)((char *)c + i * info.c_matrix.stride * infiniSizeof(desc->dtype)),
                                        info.m,
                                        info.n,
                                        info.k,
                                        transA,
                                        transB,
                                        // Optional input-side arguments left unset
                                        // (presumably max/scale pointers — verify
                                        // against the xdnn fc_fusion API).
                                        nullptr,
                                        nullptr,
                                        nullptr,
                                        info.a_matrix.ld(),
                                        info.b_matrix.ld(),
                                        info.c_matrix.ld(),
                                        alpha,
                                        beta,
                                        nullptr,
                                        xdnn::Activation_t::LINEAR, // plain GEMM
                                        nullptr)));
                            }
                            return INFINIOP_STATUS_SUCCESS;
                        });
    return ret;
}
// Allocate and initialize a Kunlun matmul descriptor.
//
// Validates the compute dtype (only F16/F32 are supported), derives matrix
// layout metadata via MatmulInfo, and shares the handle's xdnn context pool
// with the new descriptor.
//
// @param handle    Kunlun device handle providing device id and context pool.
// @param desc_ptr  out: receives the newly allocated descriptor.
// @param c_desc    output tensor descriptor (also fixes the compute dtype).
// @param a_desc    left operand tensor descriptor.
// @param b_desc    right operand tensor descriptor.
// @return INFINIOP_STATUS_SUCCESS, INFINIOP_STATUS_BAD_TENSOR_DTYPE, or the
//         status MatmulInfo reported for a shape/stride validation failure.
infiniopStatus_t kunlunCreateMatmulDescriptor(infiniopKunlunHandle_t handle,
                                              infiniopMatmulKunlunDescriptor_t *desc_ptr,
                                              infiniopTensorDescriptor_t c_desc,
                                              infiniopTensorDescriptor_t a_desc,
                                              infiniopTensorDescriptor_t b_desc) {
    infiniDtype_t dtype = c_desc->dtype;
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }

    // Initialize explicitly: the value is read below, and reading an
    // uninitialized enum would be UB if MatmulInfo ever returned without
    // writing through the out-pointer.
    infiniopStatus_t status = INFINIOP_STATUS_SUCCESS;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, false);
    if (status != INFINIOP_STATUS_SUCCESS) {
        return status;
    }

    *desc_ptr = new InfiniopMatmulKunlunDescriptor{
        INFINI_DEVICE_KUNLUN,
        dtype,
        handle->device_id,
        info,
        handle->xdnn_handle_pool};
    return INFINIOP_STATUS_SUCCESS;
}
// Report the scratch-memory requirement of the Kunlun matmul.
// This backend needs no workspace, so the answer is always zero.
//
// @param desc  matmul descriptor (not consulted — size is constant).
// @param size  out: receives 0.
// @return INFINIOP_STATUS_SUCCESS unconditionally.
infiniopStatus_t kunlunGetMatmulWorkspaceSize(infiniopMatmulKunlunDescriptor_t desc,
                                              size_t *size) {
    (void)desc; // the requirement does not depend on the problem shape
    *size = size_t{0};
    return INFINIOP_STATUS_SUCCESS;
}
// Execute C = alpha * A @ B + beta * C on the Kunlun device, dispatching to
// the dtype-specialized common implementation.
//
// @param desc            matmul descriptor created by kunlunCreateMatmulDescriptor.
// @param workspace       unused — this backend needs no scratch memory.
// @param workspace_size  unused (see above).
// @param c               output buffer.
// @param a, b            input buffers.
// @param alpha, beta     GEMM scaling factors.
// @param stream          Kunlun stream (opaque void*).
// @return the status of the dispatched kernel, or
//         INFINIOP_STATUS_BAD_TENSOR_DTYPE for an unsupported dtype.
infiniopStatus_t kunlunMatmul(infiniopMatmulKunlunDescriptor_t desc,
                              void *workspace,
                              size_t workspace_size,
                              void *c,
                              void const *a,
                              void const *b,
                              float alpha,
                              float beta,
                              void *stream) {
    switch (desc->dtype) {
    case INFINI_DTYPE_F16:
        return matmulKunlunCommon<float16>(desc, c, beta, a, b, alpha, stream);
    case INFINI_DTYPE_F32:
        return matmulKunlunCommon<float>(desc, c, beta, a, b, alpha, stream);
    default:
        return INFINIOP_STATUS_BAD_TENSOR_DTYPE;
    }
}
// Tear down a Kunlun matmul descriptor.
//
// Drops the descriptor's reference to the shared xdnn context pool, then
// frees the descriptor itself.
//
// @param desc  descriptor to destroy; must not be used afterwards.
// @return INFINIOP_STATUS_SUCCESS unconditionally.
infiniopStatus_t kunlunDestroyMatmulDescriptor(infiniopMatmulKunlunDescriptor_t desc) {
    desc->xdnn_handle_pool.reset(); // equivalent to assigning nullptr
    delete desc;
    return INFINIOP_STATUS_SUCCESS;
}
#ifndef __MATMUL_XDNN_H__
#define __MATMUL_XDNN_H__

#include "../../../devices/kunlun/common_kunlun.h"
#include "../../utils.h"
#include "../blas.h"
#include "matmul_xdnn_api.h"

// Concrete state behind the opaque infiniopMatmulKunlunDescriptor_t handle.
struct InfiniopMatmulKunlunDescriptor {
    infiniDevice_t device;   // device kind; set to INFINI_DEVICE_KUNLUN on creation
    infiniDtype_t dtype;     // compute dtype (F16 or F32)
    int device_id;           // Kunlun device index, copied from the handle
    MatmulInfo info;         // precomputed matrix layout/stride metadata
    // xdnn context pool shared with the owning device handle.
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};

#endif
#ifndef __MATMUL_XDNN_API_H__
#define __MATMUL_XDNN_API_H__

#include "../../../devices/kunlun/kunlun_handle.h"
#include "infiniop/operator.h"

// Opaque descriptor handle for the Kunlun matmul backend.
struct InfiniopMatmulKunlunDescriptor;
typedef struct InfiniopMatmulKunlunDescriptor *infiniopMatmulKunlunDescriptor_t;

// Validate the tensor descriptors and allocate a matmul descriptor for the
// Kunlun device behind `handle`.
infiniopStatus_t kunlunCreateMatmulDescriptor(infiniopKunlunHandle_t handle,
                                              infiniopMatmulKunlunDescriptor_t *desc_ptr,
                                              infiniopTensorDescriptor_t c_desc,
                                              infiniopTensorDescriptor_t a_desc,
                                              infiniopTensorDescriptor_t b_desc);

// Query the scratch-memory requirement of the matmul (always 0 on Kunlun).
infiniopStatus_t kunlunGetMatmulWorkspaceSize(infiniopMatmulKunlunDescriptor_t desc,
                                              size_t *size);

// Run C = alpha * A @ B + beta * C on the stream associated with `stream`.
infiniopStatus_t kunlunMatmul(infiniopMatmulKunlunDescriptor_t desc,
                              void *workspace,
                              size_t workspace_size,
                              void *c,
                              void const *a,
                              void const *b,
                              float alpha,
                              float beta,
                              void *stream);

// Free the descriptor and release its reference to the xdnn context pool.
infiniopStatus_t kunlunDestroyMatmulDescriptor(infiniopMatmulKunlunDescriptor_t desc);

#endif
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "ascend/matmul_ascend.h" #include "ascend/matmul_ascend.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/matmul_xdnn_api.h" #include "kunlun/matmul_kunlun.h"
#endif #endif
__C infiniStatus_t infiniopCreateMatmulDescriptor( __C infiniStatus_t infiniopCreateMatmulDescriptor(
...@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor( ...@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend); CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -78,6 +81,9 @@ infiniopGetMatmulWorkspaceSize( ...@@ -78,6 +81,9 @@ infiniopGetMatmulWorkspaceSize(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend); GET(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -118,6 +124,9 @@ __C infiniStatus_t infiniopMatmul( ...@@ -118,6 +124,9 @@ __C infiniStatus_t infiniopMatmul(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend); CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -148,6 +157,9 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) { ...@@ -148,6 +157,9 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend); DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment