Commit 2d759447 authored by qinyiqun's avatar qinyiqun
Browse files

issue/87/musa: fix 使用公共的Pool和Check Macro,添加缺失的运行时接口

parent b96284ad
#include "../../../utils.h"
#include "../pool.h"
#include "musa_handle.h"
#include "pool.h"
#include <memory>
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
// #include <musa_fp16.h>
#include <musa_fp16_mtgpu.h>
#include <musa_runtime_api.h>
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL_MUDNN(API, ::musa::dnn::Status::SUCCESS, return INFINI_STATUS_INTERNAL_ERROR)
#define CHECK_INTERNAL_MUDNN(API, EXPECT, ACTION) \
do { \
auto api_result_ = (API); \
if (api_result_ != (EXPECT)) { \
std::cerr << "Error Code " << (int)api_result_ << " in `" << #API << "`" \
<< " from " << __func__ \
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
{ ACTION; } \
} \
} while (0)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
namespace device::musa {
class Handle::Internal {
Pool<mublasHandle_t> mublas_handles;
Pool<::musa::dnn::Handle> mudnn_handles;
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
infiniStatus_t useMublas(MUstream stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
};
......
......@@ -11,26 +11,32 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
infiniStatus_t Handle::Internal::useMublas(MUstream stream, const Fn<mublasHandle_t> &f) const {
mublasHandle_t *handle = mublas_handles.pop();
if (!handle) {
handle = new mublasHandle_t;
CHECK_MUBLAS(mublasCreate(handle));
infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
std::unique_ptr<mublasHandle_t> handle;
auto opt_handle = mublas_handles.pop();
if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<mublasHandle_t>();
CHECK_MUBLAS(mublasCreate(&(*handle)));
}
CHECK_MUBLAS(mublasSetStream(*handle, stream));
CHECK_STATUS(f(*handle));
mublas_handles.push(handle);
mublas_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const {
::musa::dnn::Handle *handle = mudnn_handles.pop();
if (!handle) {
handle = new ::musa::dnn::Handle();
std::unique_ptr<::musa::dnn::Handle> handle;
auto opt_handle = mudnn_handles.pop();
if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<::musa::dnn::Handle>();
}
CHECK_MUDNN(handle->SetStream(stream));
CHECK_STATUS(f(*handle));
mudnn_handles.push(handle);
mudnn_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS;
}
......
#ifndef __POOL_MUSA_H__
#define __POOL_MUSA_H__
#include <atomic>
#include <mutex>
#include <optional>
template <class T>
class Pool {
public:
Pool() : _head(nullptr) {}
Pool(const Pool &) = delete;
Pool(Pool &&pool) noexcept : _head(pool._head.exchange(nullptr)) {}
~Pool() {
while (this->pop()) {}
}
void push(T *val) const {
Node<T> *new_node = new Node<T>(val);
new_node->next = _head.load();
while (!_head.compare_exchange_weak(new_node->next, new_node)) {}
}
T *pop() const {
Node<T> *top = _head.load();
Node<T> *new_head = nullptr;
do {
if (!top) {
return nullptr;
}
new_head = top->next;
} while (!_head.compare_exchange_weak(top, new_head));
return top->data;
}
private:
template <class U>
struct Node {
U *data;
Node<U> *next;
Node(U *data) : data(data), next(nullptr) {}
};
mutable std::atomic<Node<T> *> _head;
};
#endif // __POOL_MUSA_H__
......@@ -41,7 +41,7 @@ private:
struct Node {
U data;
Node<U> *next;
Node(U &&data) : data(data), next(nullptr) {}
Node(U &&data) : data(std::move(data)), next(nullptr) {}
};
mutable std::atomic<Node<T> *> _head;
......
......@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
......@@ -73,7 +71,7 @@ infiniStatus_t calculate(
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
(MUstream)stream,
(musaStream_t)stream,
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
mublasGemmStridedBatchedEx(
......@@ -115,25 +113,9 @@ infiniStatus_t Descriptor::calculate(void *workspace,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info,
_opaque->internal,
c,
beta,
a,
b,
alpha,
stream);
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,
_opaque->internal,
c,
beta,
a,
b,
alpha,
stream);
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -39,7 +39,8 @@ infiniStatus_t streamSynchronize(infinirtStream_t stream) {
}
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
return INFINI_STATUS_NOT_IMPLEMENTED;
CHECK_MUSART(musaStreamWaitEvent((musaStream_t)stream, (musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
......
local MUSA_HOME = os.getenv("MUSA_INSTALL_PATH")
add_includedirs(MUSA_HOME .. "/include")
add_linkdirs(MUSA_HOME .. "/lib")
add_links("libmusa.so")
add_links("libmusart.so")
add_links("libmudnn.so")
add_links("libmublas.so")
add_links("musa", "musart", "mudnn", "mublas")
rule("mu")
set_extensions(".mu")
......@@ -32,12 +29,10 @@ target("infiniop-moore")
set_kind("static")
on_install(function (target) end)
set_languages("cxx17")
set_warnings("all")
add_cxflags("-lstdc++ -Wall -fPIC")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
add_files("../src/infiniop/devices/musa/*.cc", "../src/infiniop/ops/*/musa/*.cc")
add_files("../src/infiniop/ops/*/musa/*.mu", {rule = "mu"})
add_cxflags("-lstdc++ -Wall -fPIC")
target_end()
target("infinirt-moore")
......@@ -45,7 +40,7 @@ target("infinirt-moore")
set_languages("cxx17")
on_install(function (target) end)
add_deps("infini-utils")
-- Add files
add_files("$(projectdir)/src/infinirt/musa/*.cc")
add_cxflags("-lstdc++ -Wall -Werror -fPIC")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC")
add_files("../src/infinirt/musa/*.cc")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment