Commit b92ecc31 authored by zhangyue's avatar zhangyue
Browse files

issue/340: xblas

parent a9acf208
...@@ -13,6 +13,5 @@ typedef XPUEvent kunlunEvent_t; ...@@ -13,6 +13,5 @@ typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t; typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS) #define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#endif #endif
...@@ -12,17 +12,6 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -12,17 +12,6 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id); *handle_ptr = new Handle(device_id);
}
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
auto handle = blas_handles.pop();
if (!handle) {
CHECK_CUBLAS(cublasCreate(&(*handle)));
}
CHECK_CUBLAS(cublasSetStream(*handle, stream));
CHECK_STATUS(f(*handle));
blas_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -23,13 +23,11 @@ public: ...@@ -23,13 +23,11 @@ public:
class Handle::Internal { class Handle::Internal {
Pool<xdnnHandle_t> dnn_handles; Pool<xdnnHandle_t> dnn_handles;
Pool<cublasHandle_t> blas_handles;
template <typename T> template <typename T>
using Fn = std::function<infiniStatus_t(T)>; using Fn = std::function<infiniStatus_t(T)>;
public: public:
infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const; infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
}; };
} // namespace device::kunlun } // namespace device::kunlun
......
#include "kunlun_xblas.h"
namespace device::kunlun::blas {
Handle::Handle(int device_id)
: InfiniopHandle{INFINI_DEVICE_KUNLUN, device_id},
_internal(std::make_shared<Handle::Internal>()) {}
/// Accessor for the shared Internal (cuBLAS handle pool).
/// Returns a reference so callers can copy the shared_ptr only when needed.
auto Handle::internal() const -> const std::shared_ptr<Internal> & { return _internal; }
/// Factory: allocate a blas Handle for `device_id` and hand ownership
/// to the caller through `handle_ptr`.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    auto *handle = new Handle(device_id);
    *handle_ptr = handle;
    return INFINI_STATUS_SUCCESS;
}
/// Run callback `f` with a cuBLAS handle bound to `stream`.
/// Handles are pooled: one is borrowed from `blas_handles` (or lazily
/// created on first use) and returned to the pool on success.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    auto handle = blas_handles.pop();
    if (!handle) {
        // Create into a local first: dereferencing the empty result of
        // pop() (as `&(*handle)`) is undefined behavior before a value
        // has been engaged.
        cublasHandle_t raw = nullptr;
        CHECK_CUBLAS(cublasCreate(&raw));
        handle = raw;
    }
    // Associate the handle with the caller's stream before invoking f.
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    // NOTE(review): if f (or the calls above) fails, CHECK_* returns early
    // and the borrowed handle is dropped instead of being pushed back or
    // destroyed — confirm this matches the other use* helpers' semantics.
    CHECK_STATUS(f(*handle));
    // Success: recycle the handle for subsequent calls.
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun::blas
#ifndef __KUNLUN_XBLAS_H__
#define __KUNLUN_XBLAS_H__
#include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"
#include <cublas_v2.h>
#include <functional>
#include <memory>
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
namespace device::kunlun::blas {
// Kunlun xBLAS device handle: wraps InfiniopHandle and owns a shared
// Internal that pools cuBLAS handles for reuse across operator calls.
struct Handle : public InfiniopHandle {
    class Internal;
    // Access the shared Internal (cuBLAS handle pool).
    auto internal() const -> const std::shared_ptr<Internal> &;
    Handle(int device_id);
private:
    std::shared_ptr<Internal> _internal;
public:
    // Allocate a new Handle for `device_id`; caller receives ownership
    // via `handle_ptr`.
    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
// Pools cuBLAS handles so each call borrows one instead of paying
// cublasCreate on every invocation.
class Handle::Internal {
    Pool<cublasHandle_t> blas_handles;
    // Callback type: receives a ready-to-use handle of type T.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;
public:
    // Run `f` with a pooled cuBLAS handle bound to `stream`.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
} // namespace device::kunlun::blas
#endif // __KUNLUN_XBLAS_H__
#include "gemm_kunlun.h" #include "gemm_kunlun.h"
#include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_xblas.h"
namespace op::gemm::kunlun { namespace op::gemm::kunlun {
typedef device::kunlun::Handle::Internal HandleInternal; typedef device::kunlun::blas::Handle::Internal HandleInternal;
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<HandleInternal> internal; std::shared_ptr<HandleInternal> internal;
...@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create( ...@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_); auto handle = reinterpret_cast<device::kunlun::blas::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR); auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result); CHECK_RESULT(result);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
...@@ -38,79 +35,20 @@ infiniStatus_t Descriptor::create( ...@@ -38,79 +35,20 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// template <class Tdata> infiniStatus_t Descriptor::calculate(
// infiniStatus_t calculate( void *workspace,
// MatmulInfo info, size_t workspace_size,
// std::shared_ptr<HandleInternal> internal,
// infiniDtype_t dtype,
// void *c,
// float beta,
// const void *a,
// const void *b,
// float alpha,
// kunlunStream_t stream) {
// if (info.is_transed) {
// std::swap(a, b);
// }
// auto transA = info.a_matrix.col_stride == 1 ? false : true;
// auto transB = info.b_matrix.col_stride == 1 ? false : true;
// auto unit = infiniSizeOf(dtype);
// CHECK_STATUS(internal->useXdnn(
// (kunlunStream_t)stream,
// [&](xdnnHandle_t handle) {
// for (size_t i = 0; i < info.batch; i++) {
// CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
// handle,
// (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
// (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
// (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
// info.m,
// info.n,
// info.k,
// transA,
// transB,
// nullptr,
// nullptr,
// nullptr,
// info.a_matrix.ld(),
// info.b_matrix.ld(),
// info.c_matrix.ld(),
// alpha,
// beta,
// nullptr,
// xdnn::Activation_t::LINEAR,
// nullptr)));
// }
// return INFINI_STATUS_SUCCESS;
// }));
// return INFINI_STATUS_SUCCESS;
// }
template <class Tdata>
infiniStatus_t calculate(
MatmulInfo info,
std::shared_ptr<HandleInternal> internal,
infiniDtype_t dtype,
void *c, void *c,
float beta, float beta,
const void *a, const void *a,
const void *b, const void *b,
float alpha, float alpha,
kunlunStream_t stream) { void *stream) const {
if (info.is_transed) {
std::swap(a, b);
}
auto transA = info.a_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; cudaDataType a_type, b_type, c_type;
auto transB = info.b_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
cudaDataType_t a_type, b_type, c_type;
cublasComputeType_t compute_type; cublasComputeType_t compute_type;
switch (dtype) {
switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
a_type = b_type = c_type = CUDA_R_16F; a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F; compute_type = CUBLAS_COMPUTE_32F;
...@@ -123,61 +61,48 @@ infiniStatus_t calculate( ...@@ -123,61 +61,48 @@ infiniStatus_t calculate(
a_type = b_type = c_type = CUDA_R_32F; a_type = b_type = c_type = CUDA_R_32F;
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break; break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
CHECK_STATUS(internal->useCublas( if (_info.is_transed) {
std::swap(a, b);
}
auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
CHECK_STATUS(_opaque->internal->useCublas(
(cudaStream_t)stream, (cudaStream_t)stream,
[&](cublasHandle_t handle) { [&](cublasHandle_t handle) {
CHECK_CUBLAS( CHECK_CUBLAS(
cublasGemmStridedBatchedEx( cublasGemmStridedBatchedEx(
handle, handle,
transA, op_a,
transB, op_b,
static_cast<int>(info.m), static_cast<int>(_info.m),
static_cast<int>(info.n), static_cast<int>(_info.n),
static_cast<int>(info.k), static_cast<int>(_info.k),
&alpha, &alpha,
a, a,
a_type, a_type,
static_cast<int>(info.a_matrix.ld()), static_cast<int>(_info.a_matrix.ld()),
info.a_matrix.stride, _info.a_matrix.stride,
b, b,
b_type, b_type,
static_cast<int>(info.b_matrix.ld()), static_cast<int>(_info.b_matrix.ld()),
info.b_matrix.stride, _info.b_matrix.stride,
&beta, &beta,
c, c,
c_type, c_type,
static_cast<int>(info.c_matrix.ld()), static_cast<int>(_info.c_matrix.ld()),
info.c_matrix.stride, _info.c_matrix.stride,
static_cast<int>(info.batch), static_cast<int>(_info.batch),
compute_type, compute_type,
CUBLAS_GEMM_DEFAULT)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
})); }));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t worksapce_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return op::gemm::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
case INFINI_DTYPE_F32:
return op::gemm::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::kunlun } // namespace op::gemm::kunlun
...@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME") ...@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
local XRE_DIR = path.join(KUNLUN_HOME, "xre") local XRE_DIR = path.join(KUNLUN_HOME, "xre")
local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk") local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn") local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
local XBLAS_DIR = path.join(KUNLUN_HOME, "xhpc", "xblas")
-- Add include dirs -- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"), {public = true}) add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include"), {public = true}) add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include"), {public = true}) add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
add_includedirs(path.join(XBLAS_DIR, "include"), {public = true})
-- Add link dirs -- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so")) add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so")) add_linkdirs(path.join(XDNN_DIR, "so"))
add_links("xpurt", "xpuapi") add_linkdirs(path.join(XBLAS_DIR, "so"))
add_links("xpurt", "xpuapi", "xpu_blas")
rule("xpu") rule("xpu")
set_extensions(".xpu") set_extensions(".xpu")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment