Commit a9acf208 authored by xgqdut2016, committed by zhangyue

issue/340: kunlun cublas gemm

parent cb06c721
@@ -13,5 +13,6 @@ typedef XPUEvent kunlunEvent_t;
 typedef xdnn::Context *xdnnHandle_t;
 #define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
+#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
 #endif
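Note: both CHECK_KUNLUN and CHECK_CUBLAS delegate to CHECK_INTERNAL, which is defined elsewhere in the repository and not shown in this diff. A minimal sketch of the shape such a macro would need; the expansion and the INFINI_STATUS_INTERNAL_ERROR code are assumptions, not the repo's actual definition:

// Hypothetical sketch of CHECK_INTERNAL (the real macro lives elsewhere in
// the repo): run API, compare against the expected success code, and return
// an error status from the enclosing function on mismatch.
#define CHECK_INTERNAL(API, SUCCESS_CODE)                \
    do {                                                 \
        auto _status = (API);                            \
        if (_status != (SUCCESS_CODE)) {                 \
            return INFINI_STATUS_INTERNAL_ERROR;         \
        }                                                \
    } while (0)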
@@ -12,6 +12,17 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
 infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
     *handle_ptr = new Handle(device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
+    auto handle = blas_handles.pop();
+    if (!handle) {
+        // Initialize the optional before taking the address of its value;
+        // dereferencing an empty optional is undefined behavior.
+        handle = cublasHandle_t{};
+        CHECK_CUBLAS(cublasCreate(&(*handle)));
+    }
+    CHECK_CUBLAS(cublasSetStream(*handle, stream));
+    CHECK_STATUS(f(*handle));
+    blas_handles.push(std::move(*handle));
     return INFINI_STATUS_SUCCESS;
 }
......
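Note: useCublas borrows a cached cuBLAS handle through a pop()/push() pair, creating one lazily on a cache miss, so callers do not pay cublasCreate on every invocation. The Pool<T> type itself is not part of this diff; a minimal mutex-based sketch of the interface the code relies on (the real implementation may differ, e.g. be lock-free):

#include <mutex>
#include <optional>
#include <vector>

// Minimal sketch of the Pool<T> interface used by useCublas/useXdnn.
// Members are mutable because the pool is accessed from const methods
// of Handle::Internal.
template <typename T>
class Pool {
    mutable std::mutex mtx;
    mutable std::vector<T> items;

public:
    // Returns a cached item, or std::nullopt if the pool is empty.
    std::optional<T> pop() const {
        std::lock_guard<std::mutex> lock(mtx);
        if (items.empty()) {
            return std::nullopt;
        }
        T item = std::move(items.back());
        items.pop_back();
        return item;
    }

    // Returns an item to the pool for later reuse.
    void push(T &&item) const {
        std::lock_guard<std::mutex> lock(mtx);
        items.push_back(std::move(item));
    }
};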
@@ -23,11 +23,13 @@ public:
 class Handle::Internal {
     Pool<xdnnHandle_t> dnn_handles;
+    Pool<cublasHandle_t> blas_handles;

     template <typename T>
     using Fn = std::function<infiniStatus_t(T)>;

 public:
     infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
+    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
 };

 } // namespace device::kunlun
......
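Note: both useXdnn and useCublas take the work as a callback, so call sites never see the handle pooling. A hypothetical call site, illustrative only; the function name is invented and cublasSscal (scaling a device vector) stands in for real work:

// Hypothetical call site (not from this commit). `internal` is the
// std::shared_ptr<Handle::Internal> returned by Handle::internal().
infiniStatus_t scaleExample(
    const std::shared_ptr<device::kunlun::Handle::Internal> &internal,
    cudaStream_t stream, int n, float alpha, float *d_x) {
    return internal->useCublas(
        stream,
        [&](cublasHandle_t handle) -> infiniStatus_t {
            // CHECK_CUBLAS returns early on any non-success status.
            CHECK_CUBLAS(cublasSscal(handle, n, &alpha, d_x, /*incx=*/1));
            return INFINI_STATUS_SUCCESS;
        });
}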
@@ -38,6 +38,58 @@ infiniStatus_t Descriptor::create(
     return INFINI_STATUS_SUCCESS;
 }
+// template <class Tdata>
+// infiniStatus_t calculate(
+//     MatmulInfo info,
+//     std::shared_ptr<HandleInternal> internal,
+//     infiniDtype_t dtype,
+//     void *c,
+//     float beta,
+//     const void *a,
+//     const void *b,
+//     float alpha,
+//     kunlunStream_t stream) {
+//     if (info.is_transed) {
+//         std::swap(a, b);
+//     }
+//     auto transA = info.a_matrix.col_stride == 1 ? false : true;
+//     auto transB = info.b_matrix.col_stride == 1 ? false : true;
+//     auto unit = infiniSizeOf(dtype);
+//     CHECK_STATUS(internal->useXdnn(
+//         (kunlunStream_t)stream,
+//         [&](xdnnHandle_t handle) {
+//             for (size_t i = 0; i < info.batch; i++) {
+//                 CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
+//                     handle,
+//                     (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
+//                     (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
+//                     (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
+//                     info.m,
+//                     info.n,
+//                     info.k,
+//                     transA,
+//                     transB,
+//                     nullptr,
+//                     nullptr,
+//                     nullptr,
+//                     info.a_matrix.ld(),
+//                     info.b_matrix.ld(),
+//                     info.c_matrix.ld(),
+//                     alpha,
+//                     beta,
+//                     nullptr,
+//                     xdnn::Activation_t::LINEAR,
+//                     nullptr)));
+//             }
+//             return INFINI_STATUS_SUCCESS;
+//         }));
+//     return INFINI_STATUS_SUCCESS;
+// }
 template <class Tdata>
 infiniStatus_t calculate(
     MatmulInfo info,
@@ -54,37 +106,56 @@ infiniStatus_t calculate(
         std::swap(a, b);
     }
-    auto transA = info.a_matrix.col_stride == 1 ? false : true;
-    auto transB = info.b_matrix.col_stride == 1 ? false : true;
-    auto unit = infiniSizeOf(dtype);
-    CHECK_STATUS(internal->useXdnn(
-        (kunlunStream_t)stream,
-        [&](xdnnHandle_t handle) {
-            for (size_t i = 0; i < info.batch; i++) {
-                CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
-                    handle,
-                    (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
-                    (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
-                    (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
-                    info.m,
-                    info.n,
-                    info.k,
-                    transA,
-                    transB,
-                    nullptr,
-                    nullptr,
-                    nullptr,
-                    info.a_matrix.ld(),
-                    info.b_matrix.ld(),
-                    info.c_matrix.ld(),
-                    alpha,
-                    beta,
-                    nullptr,
-                    xdnn::Activation_t::LINEAR,
-                    nullptr)));
-            }
+    auto transA = info.a_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
+    auto transB = info.b_matrix.col_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
+    cudaDataType_t a_type, b_type, c_type;
+    cublasComputeType_t compute_type;
+    switch (dtype) {
+    case INFINI_DTYPE_F16:
+        a_type = b_type = c_type = CUDA_R_16F;
+        compute_type = CUBLAS_COMPUTE_32F;
+        break;
+    case INFINI_DTYPE_BF16:
+        a_type = b_type = c_type = CUDA_R_16BF;
+        compute_type = CUBLAS_COMPUTE_32F;
+        break;
+    case INFINI_DTYPE_F32:
+        a_type = b_type = c_type = CUDA_R_32F;
+        compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+        break;
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+    CHECK_STATUS(internal->useCublas(
+        (cudaStream_t)stream,
+        [&](cublasHandle_t handle) {
+            CHECK_CUBLAS(
+                cublasGemmStridedBatchedEx(
+                    handle,
+                    transA,
+                    transB,
+                    static_cast<int>(info.m),
+                    static_cast<int>(info.n),
+                    static_cast<int>(info.k),
+                    &alpha,
+                    a,
+                    a_type,
+                    static_cast<int>(info.a_matrix.ld()),
+                    info.a_matrix.stride,
+                    b,
+                    b_type,
+                    static_cast<int>(info.b_matrix.ld()),
+                    info.b_matrix.stride,
+                    &beta,
+                    c,
+                    c_type,
+                    static_cast<int>(info.c_matrix.ld()),
+                    info.c_matrix.stride,
+                    static_cast<int>(info.batch),
+                    compute_type,
+                    CUBLAS_GEMM_DEFAULT));
             return INFINI_STATUS_SUCCESS;
         }));
     return INFINI_STATUS_SUCCESS;
......
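Note: cuBLAS assumes column-major storage, which is why transA/transB are derived from col_stride and why the is_transed branch swaps a and b. For a plain row-major C = A·B, the standard trick is to compute the column-major product C^T = B^T · A^T. A self-contained sketch, illustrative only and not from this commit; the function name is invented and cublasSgemm replaces the batched call for brevity:

#include <cublas_v2.h>

// Row-major C[m x n] = A[m x k] * B[k x n] via column-major cuBLAS:
// a row-major X[r x c] is exactly a column-major X^T[c x r] with ld = c,
// so calling gemm with swapped operands and (n, m, k) yields C^T in dC,
// which is C in row-major layout.
static bool gemm_row_major(cublasHandle_t handle,
                           int m, int n, int k,
                           const float *dA, const float *dB, float *dC) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasStatus_t st = cublasSgemm(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        n, m, k,   // dimensions of C^T = B^T * A^T
        &alpha,
        dB, n,     // B^T is n x k in column-major, ld = n
        dA, k,     // A^T is k x m in column-major, ld = k
        &beta,
        dC, n);    // C^T is n x m in column-major, ld = n
    return st == CUBLAS_STATUS_SUCCESS;
}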