Commit 2d759447 authored by qinyiqun
Browse files

issue/87/musa: fix 使用公共的Pool和Check Macro,添加缺失的运行时接口

parent b96284ad
#include "../../../utils.h" #include "../../../utils.h"
#include "../pool.h"
#include "musa_handle.h" #include "musa_handle.h"
#include "pool.h"
#include <memory>
#include <mublas.h> #include <mublas.h>
#include <mudnn.h> #include <mudnn.h>
#include <musa.h> #include <musa.h>
// #include <musa_fp16.h>
#include <musa_fp16_mtgpu.h> #include <musa_fp16_mtgpu.h>
#include <musa_runtime_api.h> #include <musa_runtime_api.h>
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS) #define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL_MUDNN(API, ::musa::dnn::Status::SUCCESS, return INFINI_STATUS_INTERNAL_ERROR) #define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
// CHECK_INTERNAL_MUDNN(API, EXPECT, ACTION)
//
// Evaluates the expression API exactly once and compares the result with
// EXPECT. On a mismatch it writes a diagnostic to std::cerr containing the
// result cast to int, the stringified API expression, the enclosing function
// name, and the file and line of the expansion, and then executes ACTION.
//
//   API    - expression whose result is checked.
//   EXPECT - value that signals success.
//   ACTION - statement(s) run after the diagnostic is printed, e.g. a
//            `return <error status>`.
//
// The body is wrapped in do { ... } while (0) so the expansion can be used
// as a single statement followed by a semicolon. The result is stored in a
// local named api_result_, which is in scope while ACTION runs.
// NOTE(review): std::cerr / std::endl must be declared at the expansion
// site; this header does not visibly include <iostream> - confirm.
#define CHECK_INTERNAL_MUDNN(API, EXPECT, ACTION) \
do { \
auto api_result_ = (API); \
if (api_result_ != (EXPECT)) { \
std::cerr << "Error Code " << (int)api_result_ << " in `" << #API << "`" \
<< " from " << __func__ \
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
{ ACTION; } \
} \
} while (0)
namespace device::musa { namespace device::musa {
class Handle::Internal { class Handle::Internal {
Pool<mublasHandle_t> mublas_handles; Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Pool<::musa::dnn::Handle> mudnn_handles; Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;
template <typename T> template <typename T>
using Fn = std::function<infiniStatus_t(T)>; using Fn = std::function<infiniStatus_t(T)>;
public: public:
infiniStatus_t useMublas(MUstream stream, const Fn<mublasHandle_t> &f) const; infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const; infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
}; };
......
...@@ -11,26 +11,32 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -11,26 +11,32 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal; return _internal;
} }
infiniStatus_t Handle::Internal::useMublas(MUstream stream, const Fn<mublasHandle_t> &f) const { infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
mublasHandle_t *handle = mublas_handles.pop(); std::unique_ptr<mublasHandle_t> handle;
if (!handle) { auto opt_handle = mublas_handles.pop();
handle = new mublasHandle_t; if (opt_handle.has_value()) {
CHECK_MUBLAS(mublasCreate(handle)); handle = std::move(*opt_handle);
} else {
handle = std::make_unique<mublasHandle_t>();
CHECK_MUBLAS(mublasCreate(&(*handle)));
} }
CHECK_MUBLAS(mublasSetStream(*handle, stream)); CHECK_MUBLAS(mublasSetStream(*handle, stream));
CHECK_STATUS(f(*handle)); CHECK_STATUS(f(*handle));
mublas_handles.push(handle); mublas_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const { infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const {
::musa::dnn::Handle *handle = mudnn_handles.pop(); std::unique_ptr<::musa::dnn::Handle> handle;
if (!handle) { auto opt_handle = mudnn_handles.pop();
handle = new ::musa::dnn::Handle(); if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<::musa::dnn::Handle>();
} }
CHECK_MUDNN(handle->SetStream(stream)); CHECK_MUDNN(handle->SetStream(stream));
CHECK_STATUS(f(*handle)); CHECK_STATUS(f(*handle));
mudnn_handles.push(handle); mudnn_handles.push(std::move(handle));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
#ifndef __POOL_MUSA_H__
#define __POOL_MUSA_H__
#include <atomic>
#include <mutex>
#include <optional>
/// A LIFO pool of `T*` entries implemented as a singly linked stack whose
/// head is updated with atomic compare-and-exchange.
///
/// The pool stores the pointers it is given as-is; it never deletes the `T`
/// objects they point to. `push` and `pop` are `const` because the head is a
/// `mutable` atomic, so a pool held through a const reference can still be
/// used.
template <class T>
class Pool {
public:
    /// Creates an empty pool.
    Pool() : _head(nullptr) {}

    Pool(const Pool &) = delete;

    /// Takes over every entry of `pool`, leaving `pool` empty.
    Pool(Pool &&pool) noexcept : _head(pool._head.exchange(nullptr)) {}

    /// Releases the list nodes that are still queued.
    ///
    /// Previously the destructor drained the stack through `pop()`, which
    /// unlinks nodes without freeing them, so every remaining node leaked.
    /// The stored `T*` values are still NOT deleted here, exactly as before.
    ~Pool() {
        // Detach the whole chain first so the walk below touches only
        // nodes that are no longer reachable through `_head`.
        Node<T> *node = _head.exchange(nullptr);
        while (node) {
            Node<T> *next = node->next;
            delete node;
            node = next;
        }
    }

    /// Puts `val` on top of the stack.
    void push(T *val) const {
        Node<T> *new_node = new Node<T>(val);
        new_node->next = _head.load();
        // On failure compare_exchange_weak reloads the current head into
        // `new_node->next`, so the retry links against the fresh head.
        while (!_head.compare_exchange_weak(new_node->next, new_node)) {}
    }

    /// Removes and returns the most recently pushed pointer, or `nullptr`
    /// when the pool is empty.
    T *pop() const {
        Node<T> *top = _head.load();
        Node<T> *new_head = nullptr;
        do {
            if (!top) {
                return nullptr;
            }
            new_head = top->next;
        } while (!_head.compare_exchange_weak(top, new_head));
        // NOTE(review): the unlinked node is intentionally left allocated
        // here, as in the original. Another thread running pop() may have
        // loaded the same `top` and still be reading `top->next`; freeing it
        // immediately would need a safe-reclamation scheme.
        return top->data;
    }

private:
    /// One stack entry: the stored pointer plus the link to the entry below.
    template <class U>
    struct Node {
        U *data;
        Node<U> *next;
        Node(U *data) : data(data), next(nullptr) {}
    };

    /// Top of the stack; `mutable` so the const `push`/`pop` can update it.
    mutable std::atomic<Node<T> *> _head;
};
#endif // __POOL_MUSA_H__
...@@ -41,7 +41,7 @@ private: ...@@ -41,7 +41,7 @@ private:
struct Node { struct Node {
U data; U data;
Node<U> *next; Node<U> *next;
Node(U &&data) : data(data), next(nullptr) {} Node(U &&data) : data(std::move(data)), next(nullptr) {}
}; };
mutable std::atomic<Node<T> *> _head; mutable std::atomic<Node<T> *> _head;
......
...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create( ...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
auto handle = reinterpret_cast<device::musa::Handle *>(handle_); auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status; infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR); auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
...@@ -73,7 +71,7 @@ infiniStatus_t calculate( ...@@ -73,7 +71,7 @@ infiniStatus_t calculate(
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T; auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas( CHECK_STATUS(_internal->useMublas(
(MUstream)stream, (musaStream_t)stream,
[&](mublasHandle_t handle) { [&](mublasHandle_t handle) {
CHECK_MUBLAS( CHECK_MUBLAS(
mublasGemmStridedBatchedEx( mublasGemmStridedBatchedEx(
...@@ -114,28 +112,12 @@ infiniStatus_t Descriptor::calculate(void *workspace, ...@@ -114,28 +112,12 @@ infiniStatus_t Descriptor::calculate(void *workspace,
float alpha, float alpha,
void *stream) const { void *stream) const {
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
_opaque->internal, case INFINI_DTYPE_F32:
c, return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
beta, default:
a, return INFINI_STATUS_BAD_TENSOR_DTYPE;
b,
alpha,
stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,
_opaque->internal,
c,
beta,
a,
b,
alpha,
stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} }
......
...@@ -39,7 +39,8 @@ infiniStatus_t streamSynchronize(infinirtStream_t stream) { ...@@ -39,7 +39,8 @@ infiniStatus_t streamSynchronize(infinirtStream_t stream) {
} }
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) { infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
return INFINI_STATUS_NOT_IMPLEMENTED; CHECK_MUSART(musaStreamWaitEvent((musaStream_t)stream, (musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
......
local MUSA_HOME = os.getenv("MUSA_INSTALL_PATH") local MUSA_HOME = os.getenv("MUSA_INSTALL_PATH")
add_includedirs(MUSA_HOME .. "/include") add_includedirs(MUSA_HOME .. "/include")
add_linkdirs(MUSA_HOME .. "/lib") add_linkdirs(MUSA_HOME .. "/lib")
add_links("libmusa.so") add_links("musa", "musart", "mudnn", "mublas")
add_links("libmusart.so")
add_links("libmudnn.so")
add_links("libmublas.so")
rule("mu") rule("mu")
set_extensions(".mu") set_extensions(".mu")
...@@ -32,12 +29,10 @@ target("infiniop-moore") ...@@ -32,12 +29,10 @@ target("infiniop-moore")
set_kind("static") set_kind("static")
on_install(function (target) end) on_install(function (target) end)
set_languages("cxx17") set_languages("cxx17")
set_warnings("all") set_warnings("all", "error")
add_cxflags("-lstdc++ -Wall -fPIC") add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
add_files("../src/infiniop/devices/musa/*.cc", "../src/infiniop/ops/*/musa/*.cc") add_files("../src/infiniop/devices/musa/*.cc", "../src/infiniop/ops/*/musa/*.cc")
add_files("../src/infiniop/ops/*/musa/*.mu", {rule = "mu"}) add_files("../src/infiniop/ops/*/musa/*.mu", {rule = "mu"})
add_cxflags("-lstdc++ -Wall -fPIC")
target_end() target_end()
target("infinirt-moore") target("infinirt-moore")
...@@ -45,7 +40,7 @@ target("infinirt-moore") ...@@ -45,7 +40,7 @@ target("infinirt-moore")
set_languages("cxx17") set_languages("cxx17")
on_install(function (target) end) on_install(function (target) end)
add_deps("infini-utils") add_deps("infini-utils")
-- Add files set_warnings("all", "error")
add_files("$(projectdir)/src/infinirt/musa/*.cc") add_cxflags("-lstdc++", "-fPIC")
add_cxflags("-lstdc++ -Wall -Werror -fPIC") add_files("../src/infinirt/musa/*.cc")
target_end() target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment