Unverified Commit f81666d2 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #26 from PanZezhong1725/kunlun

issue/25:昆仑芯 matmul
parents 45175dbf a563f3de
......@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/ascend_handle.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
#endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
infiniDevice_t device) {
......@@ -37,6 +40,11 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr,
case INFINI_DEVICE_ASCEND: {
return createAscendHandle((infiniopAscendHandle_t *)handle_ptr);
}
#endif
#ifdef ENABLE_KUNLUN_API
case INFINI_DEVICE_KUNLUN: {
return createKunlunHandle((infiniopKunlunHandle_t *)handle_ptr);
}
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -62,6 +70,11 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
case INFINI_DEVICE_ASCEND: {
return destroyAscendHandle((infiniopAscendHandle_t)handle);
}
#endif
#ifdef ENABLE_KUNLUN_API
case INFINI_DEVICE_KUNLUN: {
return destroyKunlunHandle((infiniopKunlunHandle_t)handle);
}
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __INFINIOP_COMMON_KUNLUN_H__
#define __INFINIOP_COMMON_KUNLUN_H__

// Shared helpers for the Kunlun (Baidu XPU) device backend: xdnn context
// pooling, stream typedef, and runtime error checking.

#include "../pool.h"
#include "infinicore.h"
#include "kunlun_handle.h"
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
#include "xpu/xdnn.h"
#include <memory>

// Shorthand for Baidu's XPU DNN library namespace.
namespace xdnn = baidu::xpu::api;

// An xdnn context is the per-stream handle used for kernel launches.
typedef xdnn::Context *xdnnHandle_t;
// Kunlun's native stream type, aliased for backend-neutral naming.
typedef XPUStream KunlunStream_t;

// Evaluate an XPU runtime call; on failure, log file/line plus the runtime's
// error string to stderr and return INFINI_STATUS_INTERNAL_ERROR from the
// enclosing function. NOTE: the macro contains a `return`, so it may only be
// used inside functions returning infiniStatus_t.
#define CHECK_KUNLUN(call) \
{ \
auto err = call; \
if (XPU_SUCCESS != err) { \
fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \
__LINE__, xpu_strerror(err)); \
return INFINI_STATUS_INTERNAL_ERROR; \
} \
}
// Device handle for the Kunlun backend. Holds a shared pool of xdnn contexts
// that operator descriptors reference (see matmul's Opaque), so contexts can
// be reused across calls instead of being re-created.
struct InfiniopKunlunHandle {
    infiniDevice_t device;    // set to INFINI_DEVICE_KUNLUN by createKunlunHandle
    int device_id;            // XPU device index active when the handle was created
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};
// Borrow an xdnn context from `pool`, bind it to `stream`, invoke `f` with
// it, and return the context to the pool afterwards.
//
// `f` is any callable taking an xdnnHandle_t. If the pool is empty, a new
// context is created lazily and donated to the pool when `f` finishes.
template <typename T>
void use_xdnn(std::shared_ptr<Pool<xdnnHandle_t>> &pool, KunlunStream_t stream, const T &f) {
    auto handle = pool->pop();
    if (!handle) {
        // BUG FIX: the original wrote `*handle = xdnn::create_context();`,
        // which dereferences a disengaged optional (undefined behavior —
        // optional::operator* requires has_value()). Assign to the optional
        // itself so it becomes engaged with the fresh context.
        handle = xdnn::create_context();
    }
    (*handle)->set_stream(stream);
    f(*handle);
    pool->push(std::move(*handle));
}
#endif //__INFINIOP_COMMON_KUNLUN_H__
#include "common_kunlun.h"
// Create a Kunlun device handle bound to the currently active XPU device,
// seeding its xdnn context pool with one pre-built context.
// Returns INFINI_STATUS_SUCCESS, or INFINI_STATUS_INTERNAL_ERROR if the
// XPU runtime cannot report the current device.
infiniStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr) {
    // Query which XPU device is currently selected.
    int current_dev;
    CHECK_KUNLUN(xpu_current_device(&current_dev));

    // Seed the shared context pool so the first operator call does not have
    // to create a context itself.
    auto context_pool = std::make_shared<Pool<xdnnHandle_t>>();
    context_pool->push(xdnn::create_context());

    *handle_ptr = new InfiniopKunlunHandle{
        INFINI_DEVICE_KUNLUN,
        current_dev,
        std::move(context_pool),
    };
    return INFINI_STATUS_SUCCESS;
}
// Destroy a handle created by createKunlunHandle. Drops this handle's
// reference to the context pool, then frees the handle itself.
infiniStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle_ptr) {
    handle_ptr->xdnn_handle_pool.reset();
    delete handle_ptr;
    return INFINI_STATUS_SUCCESS;
}
#ifndef __INFINIOP_KUNLUN_HANDLE_H__
#define __INFINIOP_KUNLUN_HANDLE_H__

#include "infiniop/handle.h"

// Opaque handle type for the Kunlun (Baidu XPU) backend; the full definition
// lives in devices/kunlun/common_kunlun.h.
struct InfiniopKunlunHandle;
typedef struct InfiniopKunlunHandle *infiniopKunlunHandle_t;

// Allocate a Kunlun handle bound to the current XPU device.
infiniStatus_t createKunlunHandle(infiniopKunlunHandle_t *handle_ptr);
// Free a handle previously returned by createKunlunHandle.
infiniStatus_t destroyKunlunHandle(infiniopKunlunHandle_t handle);

#endif // __INFINIOP_KUNLUN_HANDLE_H__
#include "matmul_kunlun.h"
#include "../../../devices/kunlun/common_kunlun.h"
#include "../../utils.h"
namespace matmul::kunlun {
// Backend-private state of a Kunlun matmul descriptor: a reference to the
// device handle's shared xdnn context pool, captured at create() time.
struct Descriptor::Opaque {
    std::shared_ptr<Pool<xdnnHandle_t>> xdnn_handle_pool;
};

Descriptor::~Descriptor() {
    // Frees the Opaque (and its pool reference); the contexts themselves stay
    // alive as long as any other owner of the shared pool exists.
    delete _opaque;
}
// Build a Kunlun matmul descriptor from the three tensor descriptors.
// Supports only F16/F32 element types; validates the matmul shape via
// MatmulInfo before allocating the descriptor.
infiniStatus_t Descriptor::create(infiniopHandle_t handle_,
                                  Descriptor **desc_ptr,
                                  infiniopTensorDescriptor_t c_desc,
                                  infiniopTensorDescriptor_t a_desc,
                                  infiniopTensorDescriptor_t b_desc) {
    auto kunlun_handle = reinterpret_cast<infiniopKunlunHandle_t>(handle_);

    // Reject any element type other than f16/f32 up front.
    const auto dtype = c_desc->dtype;
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // MatmulInfo validates shapes/strides and reports failure via out-status.
    infiniStatus_t info_status;
    auto info = MatmulInfo(c_desc, a_desc, b_desc, &info_status, MatrixLayout::ROW_MAJOR);
    if (info_status != INFINI_STATUS_SUCCESS) {
        return info_status;
    }

    // Workspace size is 0: the Kunlun matmul needs no scratch buffer.
    *desc_ptr = new Descriptor(
        dtype, info, 0,
        new Opaque{kunlun_handle->xdnn_handle_pool},
        kunlun_handle->device, kunlun_handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Run a (possibly batched) matmul via xdnn::fc_fusion on a pooled context.
// Presumably computes C = alpha * op(A) * op(B) + beta * C per batch element
// (matches the alpha/beta arguments passed to fc_fusion — confirm against the
// xdnn fc_fusion documentation).
//
// Tdata: element type of all three matrices (float16 or float).
// dtype is used only to derive the element size for byte-offset arithmetic.
template <class Tdata>
void calculate(
    const MatmulInfo &info,
    std::shared_ptr<Pool<xdnnHandle_t>> &xdnn_handle_pool,
    infiniDtype_t dtype,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    KunlunStream_t stream) {
    // When MatmulInfo flagged the whole problem as transposed, A and B swap roles.
    if (info.is_transed) {
        std::swap(a, b);
    }
    // A unit column stride means the operand is contiguous row-major as
    // stored; otherwise ask xdnn to treat it as transposed.
    auto transA = info.a_matrix.col_stride == 1 ? false : true;
    auto transB = info.b_matrix.col_stride == 1 ? false : true;
    auto unit = infiniSizeof(dtype);
    use_xdnn(xdnn_handle_pool,
             (KunlunStream_t)stream,
             [&](xdnnHandle_t handle) {
                 // One fc_fusion launch per batch element. Batch strides are in
                 // elements, so multiply by `unit` for byte offsets.
                 // NOTE(review): the nullptr arguments appear to be the unused
                 // max-value/bias/scale slots of fc_fusion — verify against the
                 // xdnn API before changing argument order.
                 for (size_t i = 0; i < info.batch; i++) {
                     xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                         handle,
                         (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
                         (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
                         (Tdata *)((char *)c + i * info.c_matrix.stride * unit),
                         info.m,
                         info.n,
                         info.k,
                         transA,
                         transB,
                         nullptr,
                         nullptr,
                         nullptr,
                         info.a_matrix.ld(),
                         info.b_matrix.ld(),
                         info.c_matrix.ld(),
                         alpha,
                         beta,
                         nullptr,
                         xdnn::Activation_t::LINEAR,
                         nullptr);
                 }
             });
}
// Dispatch the matmul to the typed kernel matching the element type recorded
// at descriptor creation. `workspace`/`workspace_size` are accepted for
// interface uniformity but unused (this backend needs no scratch space).
// Fix: parameter was misspelled `worksapce_size` in the original definition.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    float beta,
    const void *a,
    const void *b,
    float alpha,
    void *stream) const {
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        kunlun::calculate<float16>(_info, _opaque->xdnn_handle_pool, _dtype, c, beta, a, b, alpha, (KunlunStream_t)stream);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F32:
        kunlun::calculate<float>(_info, _opaque->xdnn_handle_pool, _dtype, c, beta, a, b, alpha, (KunlunStream_t)stream);
        return INFINI_STATUS_SUCCESS;
    default:
        // create() only admits F16/F32, so this is defensive.
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace matmul::kunlun
#ifndef __MATMUL_KUNLUN_H__
#define __MATMUL_KUNLUN_H__

#include "../matmul.h"

// Expand the shared matmul Descriptor declaration (see the DESCRIPTOR macro
// in ../matmul.h) into the matmul::kunlun namespace.
DESCRIPTOR(kunlun)

#endif // __MATMUL_KUNLUN_H__
......@@ -12,6 +12,9 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/matmul_ascend.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/matmul_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateMatmulDescriptor(
infiniopHandle_t handle,
......@@ -43,6 +46,9 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor(
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -75,6 +81,9 @@ infiniopGetMatmulWorkspaceSize(
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -115,6 +124,9 @@ __C infiniStatus_t infiniopMatmul(
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -145,6 +157,9 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
#ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -166,6 +166,11 @@ def get_args():
action="store_true",
help="Run ASCEND NPU test",
)
parser.add_argument(
"--kunlun",
action="store_true",
help="Run KUNLUN XPU test",
)
return parser.parse_args()
......@@ -428,6 +433,9 @@ def get_test_devices(args):
torch.npu.set_device(0) # Ascend NPU needs explicit device initialization
devices_to_test.append(InfiniDeviceEnum.ASCEND)
if args.kunlun:
import torch_xmlir
devices_to_test.append(InfiniDeviceEnum.KUNLUN)
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
......@@ -100,6 +100,17 @@ if has_config("sugon-dcu") then
add_defines("ENABLE_SUGON_CUDA_API")
end
-- 昆仑芯
option("kunlun-xpu")
set_default(false)
set_showmenu(true)
set_description("Enable or disable Kunlun XPU kernel")
option_end()
if has_config("kunlun-xpu") then
add_defines("ENABLE_KUNLUN_API")
includes("xmake/kunlun.lua")
end
target("infiniop")
set_kind("shared")
......@@ -134,6 +145,9 @@ target("infiniop")
if has_config("metax-gpu") then
add_deps("metax-gpu")
end
if has_config("kunlun-xpu") then
add_deps("infiniop-kunlun")
end
set_languages("cxx17")
add_files("src/infiniop/devices/handle.cc")
add_files("src/infiniop/ops/*/operator.cc")
......
-- Build rules for the Kunlun (Baidu XPU) backend static library.
-- NOTE(review): the calls below sit at root scope (before any target()),
-- which in xmake applies them beyond this target to targets defined after
-- this file is included — presumably intentional so `infiniop` also picks up
-- the define and SDK paths; confirm against the top-level xmake.lua.
add_defines("ENABLE_KUNLUN_API")

-- KUNLUN_HOME must point at the installed XPU toolkit.
local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
-- Add include dirs
add_includedirs(path.join(KUNLUN_HOME, "include"), {public=true})
add_linkdirs(path.join(KUNLUN_HOME, "lib64"))
add_links("xpurt")
add_links("xpuapi")

target("infiniop-kunlun")
    -- Other configs
    set_kind("static")
    set_languages("cxx17")
    -- No artifacts to install for this intermediate static library.
    on_install(function (target) end)
    -- Add files
    add_files("$(projectdir)/src/infiniop/devices/kunlun/*.cc", "$(projectdir)/src/infiniop/ops/*/kunlun/*.cc")
    add_cxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment