Commit a9acf208 authored by xgqdut2016's avatar xgqdut2016 Committed by zhangyue
Browse files

issue/340: kunlun cublas gemm

parent cb06c721
......@@ -13,5 +13,6 @@ typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#endif
......@@ -12,6 +12,17 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id);
}
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
auto handle = blas_handles.pop();
if (!handle) {
CHECK_CUBLAS(cublasCreate(&(*handle)));
}
CHECK_CUBLAS(cublasSetStream(*handle, stream));
CHECK_STATUS(f(*handle));
blas_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS;
}
......
......@@ -23,11 +23,13 @@ public:
class Handle::Internal {
Pool<xdnnHandle_t> dnn_handles;
Pool<cublasHandle_t> blas_handles;
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
} // namespace device::kunlun
......
......@@ -38,6 +38,58 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
// template <class Tdata>
// infiniStatus_t calculate(
// MatmulInfo info,
// std::shared_ptr<HandleInternal> internal,
// infiniDtype_t dtype,
// void *c,
// float beta,
// const void *a,
// const void *b,
// float alpha,
// kunlunStream_t stream) {
// if (info.is_transed) {
// std::swap(a, b);
// }
// auto transA = info.a_matrix.col_stride == 1 ? false : true;
// auto transB = info.b_matrix.col_stride == 1 ? false : true;
// auto unit = infiniSizeOf(dtype);
// CHECK_STATUS(internal->useXdnn(
// (kunlunStream_t)stream,
// [&](xdnnHandle_t handle) {
// for (size_t i = 0; i < info.batch; i++) {
// CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
// handle,
// (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
// (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
// (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
// info.m,
// info.n,
// info.k,
// transA,
// transB,
// nullptr,
// nullptr,
// nullptr,
// info.a_matrix.ld(),
// info.b_matrix.ld(),
// info.c_matrix.ld(),
// alpha,
// beta,
// nullptr,
// xdnn::Activation_t::LINEAR,
// nullptr)));
// }
// return INFINI_STATUS_SUCCESS;
// }));
// return INFINI_STATUS_SUCCESS;
// }
template <class Tdata>
infiniStatus_t calculate(
MatmulInfo info,
......@@ -54,37 +106,56 @@ infiniStatus_t calculate(
std::swap(a, b);
}
auto transA = info.a_matrix.col_stride == 1 ? false : true;
auto transB = info.b_matrix.col_stride == 1 ? false : true;
auto transA = info.a_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto transB = info.b_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
cudaDataType_t a_type, b_type, c_type;
cublasComputeType_t compute_type;
switch (dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = CUDA_R_16BF;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = CUDA_R_32F;
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
auto unit = infiniSizeOf(dtype);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_STATUS(internal->useXdnn(
(kunlunStream_t)stream,
[&](xdnnHandle_t handle) {
for (size_t i = 0; i < info.batch; i++) {
CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
CHECK_STATUS(internal->useCublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
CHECK_CUBLAS(
cublasGemmStridedBatchedEx(
handle,
(Tdata *)((char *)a + i * info.a_matrix.stride * unit),
(Tdata *)((char *)b + i * info.b_matrix.stride * unit),
(Tdata *)((char *)c + i * info.c_matrix.stride * unit),
info.m,
info.n,
info.k,
transA,
transB,
nullptr,
nullptr,
nullptr,
info.a_matrix.ld(),
info.b_matrix.ld(),
info.c_matrix.ld(),
alpha,
beta,
nullptr,
xdnn::Activation_t::LINEAR,
nullptr)));
}
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
}));
return INFINI_STATUS_SUCCESS;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment