Commit b92ecc31 authored by zhangyue's avatar zhangyue
Browse files

issue/340: xblas

parent a9acf208
......@@ -13,6 +13,5 @@ typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#endif
......@@ -12,17 +12,6 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
// Factory: allocate a new device handle and hand ownership to the caller
// through `handle_ptr`.
// BUG FIX: the function is declared to return infiniStatus_t but had no
// return statement — undefined behavior in C++. Mirror the sibling
// implementation (blas::Handle::create) and report success explicitly.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(device_id);
    return INFINI_STATUS_SUCCESS;
}
// Borrow a cuBLAS handle from the pool (creating one on first use), bind it
// to `stream`, run the callback `f` with it, and return it to the pool.
// Returns the first failing status from cuBLAS or from `f`.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    auto handle = blas_handles.pop();
    if (!handle) {
        // BUG FIX: the previous code called cublasCreate(&(*handle)) while the
        // pooled slot was still empty — dereferencing a disengaged
        // std::optional-like value is undefined behavior. Create into a local
        // first, then engage the slot by assignment.
        cublasHandle_t raw;
        CHECK_CUBLAS(cublasCreate(&raw));
        handle = raw;
    }
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    CHECK_STATUS(f(*handle));
    // NOTE(review): on any early return above the handle is never pushed back,
    // so it is lost rather than reused — acceptable on error paths, but worth
    // confirming against Pool's intended contract.
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
......
......@@ -23,13 +23,11 @@ public:
// Device-internal state: lazily-populated pools of vendor library handles
// that are shared and reused across operator invocations.
class Handle::Internal {
    // Reusable XDNN contexts, one borrowed per useXdnn() call.
    Pool<xdnnHandle_t> dnn_handles;
    // Reusable cuBLAS handles, one borrowed per useCublas() call.
    Pool<cublasHandle_t> blas_handles;

    // Callback signature: receives a ready library handle, returns a status.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    // Borrow an XDNN context bound to `stream`, run `f`, return it to the pool.
    infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
    // Borrow a cuBLAS handle bound to `stream`, run `f`, return it to the pool.
    // NOTE(review): cuBLAS types on a Kunlun handle look unusual — presumably
    // the XPU stack ships a cuBLAS-compatible BLAS; confirm against
    // kunlun_common.h.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
} // namespace device::kunlun
......
#include "kunlun_xblas.h"
namespace device::kunlun::blas {
// Construct a Kunlun BLAS handle for `device_id`: tags the base
// InfiniopHandle with the KUNLUN device type and allocates the shared
// internal pool state (created lazily-empty; handles are made on demand).
Handle::Handle(int device_id)
    : InfiniopHandle{INFINI_DEVICE_KUNLUN, device_id},
      _internal(std::make_shared<Handle::Internal>()) {}
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
// Factory: allocate a blas::Handle for `device_id` and return it to the
// caller through `handle_ptr` as a generic InfiniopHandle.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    auto *created = new Handle(device_id);
    *handle_ptr = created;
    return INFINI_STATUS_SUCCESS;
}
// Borrow a cuBLAS handle from the pool (creating one on first use), bind it
// to `stream`, run the callback `f` with it, and return it to the pool.
// Returns the first failing status from cuBLAS or from `f`.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    auto handle = blas_handles.pop();
    if (!handle) {
        // BUG FIX: the previous code called cublasCreate(&(*handle)) while the
        // pooled slot was still empty — dereferencing a disengaged
        // std::optional-like value is undefined behavior. Create into a local
        // first, then engage the slot by assignment.
        cublasHandle_t raw;
        CHECK_CUBLAS(cublasCreate(&raw));
        handle = raw;
    }
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    CHECK_STATUS(f(*handle));
    // NOTE(review): on any early return above the handle is never pushed back,
    // so it is lost rather than reused — acceptable on error paths, but worth
    // confirming against Pool's intended contract.
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun::blas
#ifndef __KUNLUN_XBLAS_H__
#define __KUNLUN_XBLAS_H__
#include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"

#include <cublas_v2.h>
#include <functional>
#include <memory>
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
namespace device::kunlun::blas {
// Public handle for the Kunlun xBLAS backend. Wraps the generic
// InfiniopHandle and owns a shared pool of cuBLAS handles (see Internal).
struct Handle : public InfiniopHandle {
    class Internal;

    // Accessor for the shared internal handle-pool state.
    auto internal() const -> const std::shared_ptr<Internal> &;

    // FIX: single-argument constructor marked explicit to forbid accidental
    // implicit int -> Handle conversions; callers use direct-init
    // (`new Handle(device_id)`), which is unaffected.
    explicit Handle(int device_id);

private:
    std::shared_ptr<Internal> _internal;

public:
    // Factory: allocates a Handle for `device_id` into *handle_ptr.
    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
// Device-internal state: a lazily-populated pool of cuBLAS handles shared
// and reused across BLAS operator invocations.
class Handle::Internal {
    // Reusable cuBLAS handles; created on demand inside useCublas().
    Pool<cublasHandle_t> blas_handles;

    // Callback signature: receives a ready library handle, returns a status.
    // NOTE(review): uses std::function but this header does not include
    // <functional> — presumably pulled in transitively; worth confirming.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    // Borrow a cuBLAS handle bound to `stream`, run `f` with it, and return
    // the handle to the pool on success.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
} // namespace device::kunlun::blas
#endif // __KUNLUN_XBLAS_H__
#include "gemm_kunlun.h"
#include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_xblas.h"
namespace op::gemm::kunlun {
typedef device::kunlun::Handle::Internal HandleInternal;
typedef device::kunlun::blas::Handle::Internal HandleInternal;
struct Descriptor::Opaque {
std::shared_ptr<HandleInternal> internal;
......@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto handle = reinterpret_cast<device::kunlun::blas::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
......@@ -38,79 +35,20 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
// template <class Tdata>
// infiniStatus_t calculate(
// MatmulInfo info,
// std::shared_ptr<HandleInternal> internal,
// infiniDtype_t dtype,
// void *c,
// float beta,
// const void *a,
// const void *b,
// float alpha,
// kunlunStream_t stream) {
// if (info.is_transed) {
// std::swap(a, b);
// }
// auto transA = info.a_matrix.col_stride == 1 ? false : true;
// auto transB = info.b_matrix.col_stride == 1 ? false : true;
// auto unit = infiniSizeOf(dtype);
// CHECK_STATUS(internal->useXdnn(
// (kunlunStream_t)stream,
// [&](xdnnHandle_t handle) {
// for (size_t i = 0; i < info.batch; i++) {
// CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
// handle,
// (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
// (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
// (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
// info.m,
// info.n,
// info.k,
// transA,
// transB,
// nullptr,
// nullptr,
// nullptr,
// info.a_matrix.ld(),
// info.b_matrix.ld(),
// info.c_matrix.ld(),
// alpha,
// beta,
// nullptr,
// xdnn::Activation_t::LINEAR,
// nullptr)));
// }
// return INFINI_STATUS_SUCCESS;
// }));
// return INFINI_STATUS_SUCCESS;
// }
template <class Tdata>
infiniStatus_t calculate(
MatmulInfo info,
std::shared_ptr<HandleInternal> internal,
infiniDtype_t dtype,
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
kunlunStream_t stream) {
if (info.is_transed) {
std::swap(a, b);
}
void *stream) const {
auto transA = info.a_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto transB = info.b_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
cudaDataType_t a_type, b_type, c_type;
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;
switch (dtype) {
switch (_dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
......@@ -123,61 +61,48 @@ infiniStatus_t calculate(
a_type = b_type = c_type = CUDA_R_32F;
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_STATUS(internal->useCublas(
if (_info.is_transed) {
std::swap(a, b);
}
auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
CHECK_STATUS(_opaque->internal->useCublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
CHECK_CUBLAS(
cublasGemmStridedBatchedEx(
handle,
transA,
transB,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
op_a,
op_b,
static_cast<int>(_info.m),
static_cast<int>(_info.n),
static_cast<int>(_info.k),
&alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
static_cast<int>(_info.a_matrix.ld()),
_info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
static_cast<int>(_info.b_matrix.ld()),
_info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
static_cast<int>(_info.c_matrix.ld()),
_info.c_matrix.stride,
static_cast<int>(_info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
return INFINI_STATUS_SUCCESS;
}));
return INFINI_STATUS_SUCCESS;
}
// Dispatch the GEMM to the dtype-specialized calculate<>() implementation.
// `workspace` / `workspace_size` are accepted for interface parity with other
// backends but are not used by this path.
// FIX: corrected the misspelled parameter name `worksapce_size` (parameter
// names are not part of the call interface in C++, so callers are unaffected).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return op::gemm::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
    case INFINI_DTYPE_F32:
        return op::gemm::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
    default:
        // Reject dtypes that have no specialized kernel.
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::gemm::kunlun
......@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
-- Kunlun toolchain layout under $KUNLUN_HOME: XRE (runtime), XTDK (toolkit),
-- XDNN (DNN library) and xBLAS (BLAS library) from the XHPC bundle.
local XRE_DIR = path.join(KUNLUN_HOME, "xre")
local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
local XBLAS_DIR = path.join(KUNLUN_HOME, "xhpc", "xblas")
-- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
add_includedirs(path.join(XBLAS_DIR, "include"), {public = true})
-- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so"))
-- NOTE(review): the two add_links calls below look like unmarked before/after
-- lines from a diff; the second (with "xpu_blas") supersedes the first —
-- confirm against the repository's current xmake.lua before relying on this.
add_links("xpurt", "xpuapi")
add_linkdirs(path.join(XBLAS_DIR, "so"))
add_links("xpurt", "xpuapi", "xpu_blas")
rule("xpu")
set_extensions(".xpu")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment