Unverified Commit 1d064392 authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #409 from InfiniTensor/issue/340

Issue/340 接入昆仑芯 XBLAS
parents e221916d b92ecc31
#include "kunlun_xblas.h"
namespace device::kunlun::blas {
// Construct a Kunlun XBLAS device handle bound to `device_id`.
// Registers the device kind/id with the base InfiniopHandle and
// allocates the shared Internal state (the cuBLAS-handle pool).
Handle::Handle(int device_id)
    : InfiniopHandle{INFINI_DEVICE_KUNLUN, device_id},
      _internal(std::make_shared<Handle::Internal>()) {}
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
// Factory: allocate a new Handle for `device_id` and hand ownership
// to the caller through `handle_ptr`. Always succeeds.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    auto *handle = new Handle(device_id);
    *handle_ptr = handle;
    return INFINI_STATUS_SUCCESS;
}
// Run `f` with a cuBLAS handle bound to `stream`.
// Handles are cached in `blas_handles`; one is borrowed (or lazily
// created when the pool is empty) and returned to the pool afterwards.
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
    auto handle = blas_handles.pop();
    if (!handle) {
        // Pool empty: create a fresh handle. Create into a local first and
        // assign — the original wrote `cublasCreate(&(*handle))`, which
        // dereferences an empty optional-like value (undefined behavior).
        cublasHandle_t raw;
        CHECK_CUBLAS(cublasCreate(&raw));
        handle = raw;
    }
    CHECK_CUBLAS(cublasSetStream(*handle, stream));
    CHECK_STATUS(f(*handle));
    // NOTE(review): if SetStream or f fails, the CHECK_* early-return skips
    // this push and the handle is not recycled — same pattern as the other
    // backends in this codebase; confirm whether leaking-on-error is intended.
    blas_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun::blas
#ifndef __KUNLUN_XBLAS_H__
#define __KUNLUN_XBLAS_H__
#include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"
#include <cublas_v2.h>
#include <memory>
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
namespace device::kunlun::blas {
// Device handle for the Kunlun XBLAS backend.
// NOTE(review): the implementation goes through a cuBLAS-compatible API
// (see the cublas_v2.h include above) — presumably provided by Kunlun's
// xpu_blas library; confirm against the build configuration.
struct Handle : public InfiniopHandle {
    // Opaque shared state (defined below): the cuBLAS-handle pool.
    class Internal;
    // Accessor for the shared state; descriptors may copy the shared_ptr
    // to keep the state alive beyond the Handle itself.
    auto internal() const -> const std::shared_ptr<Internal> &;
    Handle(int device_id);

private:
    std::shared_ptr<Internal> _internal;

public:
    // Allocates a new Handle and returns it through `handle_ptr`.
    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
// Shared backend state: a pool of cached cuBLAS handles so repeated
// operator calls do not pay handle creation each time.
class Handle::Internal {
    // Cached handles, borrowed/returned via pop/push.
    // NOTE(review): pop/push come from ../pool.h and are called from a
    // const member function — presumed internally mutable and thread-safe;
    // confirm against Pool's implementation.
    Pool<cublasHandle_t> blas_handles;

    // Callback type: receives a ready handle, returns a status.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    // Runs `f` with a pooled cuBLAS handle whose stream is set to `stream`.
    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
};
} // namespace device::kunlun::blas
#endif // __KUNLUN_XBLAS_H__
#include "gemm_kunlun.h" #include "gemm_kunlun.h"
#include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_xblas.h"
namespace op::gemm::kunlun { namespace op::gemm::kunlun {
typedef device::kunlun::Handle::Internal HandleInternal; typedef device::kunlun::blas::Handle::Internal HandleInternal;
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<HandleInternal> internal; std::shared_ptr<HandleInternal> internal;
...@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create( ...@@ -21,14 +20,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_); auto handle = reinterpret_cast<device::kunlun::blas::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR); auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result); CHECK_RESULT(result);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
...@@ -38,75 +35,74 @@ infiniStatus_t Descriptor::create( ...@@ -38,75 +35,74 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
/// Batched GEMM on Kunlun via xdnn::fc_fusion, one launch per batch element.
/// @tparam Tdata   element type of a/b/c (e.g. float16, float)
/// @param info     matmul geometry: m/n/k, batch count, per-matrix strides/ld
/// @param internal backend state providing the xdnn handle
/// @param dtype    runtime dtype of Tdata — used only to size byte offsets
/// @param c        output batch; @param a, @param b input batches
/// @param alpha, beta  GEMM scalars (c = alpha*a*b + beta*c)
/// @param stream   Kunlun stream the work is queued on
/// @return INFINI_STATUS_SUCCESS, or the first failing status
template <class Tdata>
infiniStatus_t calculate(
    const MatmulInfo &info,                          // const& — avoid copying the descriptor per call
    const std::shared_ptr<HandleInternal> &internal, // const& — no ownership taken, skip refcount churn
    infiniDtype_t dtype,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    kunlunStream_t stream) {
    if (info.is_transed) {
        // MatmulInfo recorded the operands as transposed; swap the (local)
        // operand pointers to match.
        std::swap(a, b);
    }
    // An operand counts as transposed for fc_fusion when it is not
    // contiguous along columns.
    const bool transA = info.a_matrix.col_stride != 1;
    const bool transB = info.b_matrix.col_stride != 1;
    const auto unit = infiniSizeOf(dtype); // element size in bytes
    CHECK_STATUS(internal->useXdnn(
        (kunlunStream_t)stream,
        [&](xdnnHandle_t handle) {
            // No strided-batched entry point is used here, so issue one
            // fc_fusion per batch element; batch strides are in elements,
            // hence the `* unit` byte scaling.
            for (size_t i = 0; i < info.batch; i++) {
                // The nullptr arguments are unused optional pointers
                // (quantization max / bias slots) — confirm names against
                // the xdnn fc_fusion API.
                CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                    handle,
                    (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
                    (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
                    (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
                    info.m,
                    info.n,
                    info.k,
                    transA,
                    transB,
                    nullptr,
                    nullptr,
                    nullptr,
                    info.a_matrix.ld(),
                    info.b_matrix.ld(),
                    info.c_matrix.ld(),
                    alpha,
                    beta,
                    nullptr,
                    xdnn::Activation_t::LINEAR,
                    nullptr)));
            }
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, void *workspace,
size_t worksapce_size, size_t workspace_size,
void *c, void *c,
float beta, float beta,
const void *a, const void *a,
const void *b, const void *b,
float alpha, float alpha,
void *stream) const { void *stream) const {
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return op::gemm::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream); a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = CUDA_R_16BF;
compute_type = CUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return op::gemm::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream); a_type = b_type = c_type = CUDA_R_32F;
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
if (_info.is_transed) {
std::swap(a, b);
}
auto op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
auto op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
CHECK_STATUS(_opaque->internal->useCublas(
(cudaStream_t)stream,
[&](cublasHandle_t handle) {
CHECK_CUBLAS(
cublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(_info.m),
static_cast<int>(_info.n),
static_cast<int>(_info.k),
&alpha,
a,
a_type,
static_cast<int>(_info.a_matrix.ld()),
_info.a_matrix.stride,
b,
b_type,
static_cast<int>(_info.b_matrix.ld()),
_info.b_matrix.stride,
&beta,
c,
c_type,
static_cast<int>(_info.c_matrix.ld()),
_info.c_matrix.stride,
static_cast<int>(_info.batch),
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
return INFINI_STATUS_SUCCESS;
}));
return INFINI_STATUS_SUCCESS;
} }
} // namespace op::gemm::kunlun } // namespace op::gemm::kunlun
...@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME") ...@@ -3,16 +3,19 @@ local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
local XRE_DIR = path.join(KUNLUN_HOME, "xre") local XRE_DIR = path.join(KUNLUN_HOME, "xre")
local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk") local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn") local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
local XBLAS_DIR = path.join(KUNLUN_HOME, "xhpc", "xblas")
-- Add include dirs -- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"), {public = true}) add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include"), {public = true}) add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include"), {public = true}) add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
add_includedirs(path.join(XBLAS_DIR, "include"), {public = true})
-- Add link dirs -- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so")) add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so")) add_linkdirs(path.join(XDNN_DIR, "so"))
add_links("xpurt", "xpuapi") add_linkdirs(path.join(XBLAS_DIR, "so"))
add_links("xpurt", "xpuapi", "xpu_blas")
rule("xpu") rule("xpu")
set_extensions(".xpu") set_extensions(".xpu")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment