Commit 29ac8dcd authored by zhangyue's avatar zhangyue
Browse files

issue/87: fix useXdnn to propagate the callback's infiniStatus_t to the caller

parent f4593f17
......@@ -10,17 +10,17 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
template <typename T>
using Fn = std::function<void(T)>;
void Handle::Internal::use_xdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
auto handle = dnn_handles.pop();
if (!handle) {
*handle = xdnn::create_context();
}
(*handle)->set_stream(stream);
f(*handle);
CHECK_STATUS(f(*handle));
dnn_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
......
#ifndef __INFINIOP_KUNLUN_HANDLE_H__
#define __INFINIOP_KUNLUN_HANDLE_H__
#include "../../../utils.h"
#include "../../handle.h"
#include "../pool.h"
#include <functional>
......@@ -15,6 +16,8 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
struct Handle : public InfiniopHandle {
......@@ -32,9 +35,11 @@ public:
/// Per-device internal state for the Kunlun handle.
///
/// Owns a pool of cached xdnn contexts so each useXdnn call can reuse an
/// existing context instead of creating a fresh one per invocation.
class Handle::Internal {
    // Reusable xdnn contexts; popped and pushed by useXdnn.
    Pool<xdnnHandle_t> dnn_handles;

    // Callback type: receives a ready handle and reports an infiniStatus_t
    // so failures inside the callback propagate to the caller.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    /// Binds `stream` to a pooled xdnn context and invokes `f` with it.
    /// @return INFINI_STATUS_SUCCESS, or the callback's failure status.
    infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun
......
......@@ -41,7 +41,7 @@ infiniStatus_t Descriptor::create(
}
template <class Tdata>
void calculate(
infiniStatus_t calculate(
MatmulInfo info,
std::shared_ptr<HandleInternal> internal,
infiniDtype_t dtype,
......@@ -61,11 +61,11 @@ void calculate(
auto unit = infiniSizeOf(dtype);
internal->use_xdnn(
return internal->useXdnn(
(kunlunStream_t)stream,
[&](xdnnHandle_t handle) {
for (size_t i = 0; i < info.batch; i++) {
xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
handle,
(Tdata *)((char *)a + i * info.a_matrix.stride * unit),
(Tdata *)((char *)b + i * info.b_matrix.stride * unit),
......@@ -85,8 +85,9 @@ void calculate(
beta,
nullptr,
xdnn::Activation_t::LINEAR,
nullptr);
nullptr)));
}
return INFINI_STATUS_SUCCESS;
});
}
......@@ -101,13 +102,9 @@ infiniStatus_t Descriptor::calculate(
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
op::matmul::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
return INFINI_STATUS_SUCCESS;
return op::matmul::kunlun::calculate<float16>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
case INFINI_DTYPE_F32:
op::matmul::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
return INFINI_STATUS_SUCCESS;
return op::matmul::kunlun::calculate<float>(_info, _opaque->internal, _dtype, c, beta, a, b, alpha, (kunlunStream_t)stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -10,4 +10,4 @@ INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::kunlun
#endif // __INFINIRT_KUNLUN_H__
\ No newline at end of file
#endif // __INFINIRT_KUNLUN_H__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment