Unverified Commit fb83addc authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #102 from InfiniTensor/issue/87/musa

issue/87:增加摩尔线程平台的handle
parents dc52ac84 2d759447
......@@ -14,6 +14,9 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/ascend_handle.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/musa_handle.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
#endif
......@@ -47,6 +50,9 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -81,6 +87,9 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......
#include "../../../utils.h"
#include "../pool.h"
#include "musa_handle.h"
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
#include <musa_fp16_mtgpu.h>
#include <musa_runtime_api.h>
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
namespace device::musa {
class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
};
} // namespace device::musa
#include "common_musa.h"
namespace device::musa {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>()) {}
Handle::Handle(int device_id) : Handle(INFINI_DEVICE_MOORE, device_id) {}
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
std::unique_ptr<mublasHandle_t> handle;
auto opt_handle = mublas_handles.pop();
if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<mublasHandle_t>();
CHECK_MUBLAS(mublasCreate(&(*handle)));
}
CHECK_MUBLAS(mublasSetStream(*handle, stream));
CHECK_STATUS(f(*handle));
mublas_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const {
std::unique_ptr<::musa::dnn::Handle> handle;
auto opt_handle = mudnn_handles.pop();
if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<::musa::dnn::Handle>();
}
CHECK_MUDNN(handle->SetStream(stream));
CHECK_STATUS(f(*handle));
mudnn_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(INFINI_DEVICE_MOORE, device_id);
return INFINI_STATUS_SUCCESS;
}
} // namespace device::musa
#ifndef __INFINIOP_MUSA_HANDLE_H__
#define __INFINIOP_MUSA_HANDLE_H__
#include "../../handle.h"
#include <memory>
namespace device::musa {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
auto internal() const -> const std::shared_ptr<Internal> &;
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
protected:
Handle(infiniDevice_t device, int device_id);
private:
std::shared_ptr<Internal> _internal;
};
} // namespace device::musa
#endif // __INFINIOP_MUSA_HANDLE_H__
......@@ -41,7 +41,7 @@ private:
struct Node {
U data;
Node<U> *next;
Node(U &&data) : data(data), next(nullptr) {}
Node(U &&data) : data(std::move(data)), next(nullptr) {}
};
mutable std::atomic<Node<T> *> _head;
......
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR(musa)
#endif // __GEMM_MUSA_H__
#include "../../../devices/musa/common_musa.h"
#include "../../../devices/musa/musa_handle.h"
#include "gemm_musa.h"
namespace op::gemm::musa {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor(
dtype, info, 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata>
infiniStatus_t calculate(
const MatmulInfo &info,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) {
musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) {
alpha_ = __float2half(alpha);
beta_ = __float2half(beta);
a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha;
beta_ = beta;
a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
}
if (info.is_transed) {
std::swap(a, b);
}
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
(musaStream_t)stream,
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
mublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha_,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta_,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
}));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
......@@ -17,6 +17,9 @@
#ifdef ENABLE_METAX_API
#include "maca/gemm_maca.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/gemm_musa.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
#endif
......@@ -54,6 +57,10 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -92,6 +99,9 @@ infiniopGetGemmWorkspaceSize(
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -138,6 +148,9 @@ __C infiniStatus_t infiniopGemm(
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -174,6 +187,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......
#include "infinirt_musa.h"
#include "../../utils.h"
#include <musa_runtime.h>
#include <musa_runtime_api.h>
#define CHECK_MUSART(RT_API) CHECK_INTERNAL(RT_API, musaSuccess)
namespace infinirt::musa {
infiniStatus_t getDeviceCount(int *count) {
CHECK_MUSART(musaGetDeviceCount(count));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t setDevice(int device_id) {
CHECK_MUSART(musaSetDevice(device_id));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t deviceSynchronize() {
CHECK_MUSART(musaDeviceSynchronize());
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
musaStream_t stream;
CHECK_MUSART(musaStreamCreate(&stream));
*stream_ptr = stream;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamDestroy(infinirtStream_t stream) {
CHECK_MUSART(musaStreamDestroy((musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamSynchronize(infinirtStream_t stream) {
CHECK_MUSART(musaStreamSynchronize((musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
CHECK_MUSART(musaStreamWaitEvent((musaStream_t)stream, (musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
musaEvent_t event;
CHECK_MUSART(musaEventCreate(&event));
*event_ptr = event;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
CHECK_MUSART(musaEventRecord((musaEvent_t)event, (musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) {
auto status = musaEventQuery((musaEvent_t)event);
if (status == musaSuccess) {
*status_ptr = INFINIRT_EVENT_COMPLETE;
} else if (status == musaErrorNotReady) {
*status_ptr = INFINIRT_EVENT_NOT_READY;
} else {
CHECK_MUSART(status);
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventSynchronize(infinirtEvent_t event) {
CHECK_MUSART(musaEventSynchronize((musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventDestroy(infinirtEvent_t event) {
CHECK_MUSART(musaEventDestroy((musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
CHECK_MUSART(musaMalloc(p_ptr, size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocHost(void **p_ptr, size_t size) {
CHECK_MUSART(musaMallocHost(p_ptr, size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeDevice(void *ptr) {
CHECK_MUSART(musaFree(ptr));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeHost(void *ptr) {
CHECK_MUSART(musaFreeHost(ptr));
return INFINI_STATUS_SUCCESS;
}
musaMemcpyKind toMusaMemcpyKind(infinirtMemcpyKind_t kind) {
switch (kind) {
case INFINIRT_MEMCPY_H2D:
return musaMemcpyHostToDevice;
case INFINIRT_MEMCPY_D2H:
return musaMemcpyDeviceToHost;
case INFINIRT_MEMCPY_D2D:
return musaMemcpyDeviceToDevice;
case INFINIRT_MEMCPY_H2H:
return musaMemcpyHostToHost;
default:
return musaMemcpyDefault;
}
}
infiniStatus_t memcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind) {
CHECK_MUSART(musaMemcpy(dst, src, size, toMusaMemcpyKind(kind)));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) {
CHECK_MUSART(musaMemcpyAsync(dst, src, size, toMusaMemcpyKind(kind), (musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
return mallocDevice(p_ptr, size);
}
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
} // namespace infinirt::musa
#ifndef __INFINIRT_MUSA_H__
#define __INFINIRT_MUSA_H__
#include "../infinirt_impl.h"
namespace infinirt::musa {
#ifdef ENABLE_MOORE_API
INFINIRT_DEVICE_API_IMPL
#else
INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::musa
#endif // __INFINIRT_MUSA_H__
......@@ -171,6 +171,11 @@ def get_args():
action="store_true",
help="Run METAX GPU test",
)
parser.add_argument(
"--moore",
action="store_true",
help="Run MTHREADS GPU test",
)
parser.add_argument(
"--kunlun",
action="store_true",
......@@ -443,6 +448,11 @@ def get_test_devices(args):
import torch
devices_to_test.append(InfiniDeviceEnum.METAX)
if args.moore:
import torch
import torch_musa
devices_to_test.append(InfiniDeviceEnum.MOORE)
if args.kunlun:
import torch_xmlir
......
......@@ -89,7 +89,8 @@ option("moore-gpu")
option_end()
if has_config("moore-gpu") then
add_defines("ENABLE_MUSA_API")
add_defines("ENABLE_MOORE_API")
includes("xmake/musa.lua")
end
-- 海光
......@@ -154,6 +155,9 @@ target("infinirt")
if has_config("metax-gpu") then
add_deps("infinirt-metax")
end
if has_config("moore-gpu") then
add_deps("infinirt-moore")
end
if has_config("kunlun-xpu") then
add_deps("infinirt-kunlun")
end
......@@ -197,6 +201,9 @@ target("infiniop")
if has_config("metax-gpu") then
add_deps("infiniop-metax")
end
if has_config("moore-gpu") then
add_deps("infiniop-moore")
end
if has_config("kunlun-xpu") then
add_deps("infiniop-kunlun")
end
......
local MUSA_HOME = os.getenv("MUSA_INSTALL_PATH")
add_includedirs(MUSA_HOME .. "/include")
add_linkdirs(MUSA_HOME .. "/lib")
add_links("musa", "musart", "mudnn", "mublas")
rule("mu")
set_extensions(".mu")
on_load(function (target)
target:add("includedirs", "include")
end)
on_build_file(function (target, sourcefile)
local objectfile = target:objectfile(sourcefile)
os.mkdir(path.directory(objectfile))
local mcc = MUSA_HOME .. "/bin/mcc"
local includedirs = table.concat(target:get("includedirs"), " ")
local args = {"-c", sourcefile, "-o", objectfile, "-I" .. MUSA_HOME .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"}
for _, includedir in ipairs(target:get("includedirs")) do
table.insert(args, "-I" .. includedir)
end
os.execv(mcc, args)
table.insert(target:objectfiles(), objectfile)
end)
rule_end()
target("infiniop-moore")
set_kind("static")
on_install(function (target) end)
set_languages("cxx17")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
add_files("../src/infiniop/devices/musa/*.cc", "../src/infiniop/ops/*/musa/*.cc")
add_files("../src/infiniop/ops/*/musa/*.mu", {rule = "mu"})
target_end()
target("infinirt-moore")
set_kind("static")
set_languages("cxx17")
on_install(function (target) end)
add_deps("infini-utils")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC")
add_files("../src/infinirt/musa/*.cc")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment