Commit e14dd2af authored by Zimin Li
Browse files

issue/89 add error checking, change Async malloc and free to use blocking versions, etc.

parent e9ce6db5
...@@ -6,14 +6,18 @@ ...@@ -6,14 +6,18 @@
#include "cnnl.h" #include "cnnl.h"
#include "cnrt.h" #include "cnrt.h"
#include <functional> #include <functional>
#include "../../../utils.h"
namespace device::bang { namespace device::bang {
/// Internal state of a BANG device handle: owns the pool of CNNL handles that
/// `use_cnnl` lends to callbacks.
class Handle::Internal {
    // Handles are taken with pop() and returned with push() by use_cnnl.
    Pool<cnnlHandle_t> cnnl_handles;

    /// Callback that receives a library handle of type T and reports the
    /// outcome as an infiniStatus_t, so failures inside the callback can be
    /// propagated to the caller of use_cnnl.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    /// Invokes `f` with a CNNL handle whose queue has been set to `queue`.
    /// @param queue Queue passed to cnnlSetQueue before `f` is called.
    /// @param f     Work to run with the handle; its status is forwarded.
    /// @return INFINI_STATUS_SUCCESS when setup and `f` both succeed.
    infiniStatus_t use_cnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const;
};
cnnlDataType_t getCnnlDtype(infiniDtype_t dt); cnnlDataType_t getCnnlDtype(infiniDtype_t dt);
......
...@@ -17,16 +17,13 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -17,16 +17,13 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal; return _internal;
} }
/// Runs `f` with a CNNL handle bound to `queue`.
///
/// A handle is taken from the pool when one is available; otherwise a new one
/// is created. The handle is pushed back to the pool whether or not the queue
/// binding or `f` succeeds, so a failing callback does not drop it.
///
/// @param queue Queue passed to cnnlSetQueue before `f` is invoked.
/// @param f     Callback receiving the handle; its status is propagated.
/// @return INFINI_STATUS_SUCCESS on success, otherwise whatever CHECK_BANG /
///         CHECK_STATUS return for the failing call.
infiniStatus_t Handle::Internal::use_cnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const {
    auto pooled = cnnl_handles.pop();

    cnnlHandle_t cnnl{};
    if (pooled) {
        cnnl = *pooled;
    } else {
        // Create into a local instead of dereferencing the empty pop() result,
        // and surface a creation failure instead of ignoring it.
        CHECK_BANG(cnnlCreate(&cnnl));
    }

    // Run the fallible steps inside a lambda so that an early return taken by
    // CHECK_BANG cannot skip returning the handle to the pool below.
    auto bind_and_run = [&]() -> infiniStatus_t {
        CHECK_BANG(cnnlSetQueue(cnnl, queue));
        return f(cnnl);
    };
    const infiniStatus_t status = bind_and_run();

    cnnl_handles.push(std::move(cnnl));
    CHECK_STATUS(status);
    return INFINI_STATUS_SUCCESS;
}
......
#include "matmul_bang.h" #include "matmul_bang.h"
#include "../../../devices/bang/bang_handle.h" #include "../../../devices/bang/bang_handle.h"
#include "../../../devices/bang/common_bang.h"
#include "../../../devices/bang/_internal.h" #include "../../../devices/bang/_internal.h"
#include <cnnl_extra.h> #include <cnnl_extra.h>
...@@ -100,16 +101,19 @@ infiniStatus_t Descriptor::create( ...@@ -100,16 +101,19 @@ infiniStatus_t Descriptor::create(
sizeof(int32_t)); sizeof(int32_t));
int count = 0; int count = 0;
handle->internal()->use_cnnl((cnrtQueue_t)nullptr, CHECK_STATUS(handle->internal()->use_cnnl((cnrtQueue_t)nullptr,
[&](cnnlHandle_t _handle) { [&](cnnlHandle_t _handle) {
cnnlGetBatchMatMulAlgoHeuristic( CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic(
_handle, _handle,
op, a, b, c, op, a, b, c,
NULL, 1, &algoResult, &count); NULL, 1, &algoResult, &count)
}); );
return INFINI_STATUS_SUCCESS;
}));
size_t workspace_size; size_t workspace_size;
cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size); CHECK_BANG(cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size));
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, workspace_size, dtype, info, workspace_size,
...@@ -133,10 +137,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -133,10 +137,10 @@ infiniStatus_t Descriptor::calculate(
if (_info.is_transed) { if (_info.is_transed) {
std::swap(a, b); std::swap(a, b);
} }
_opaque->internal->use_cnnl( CHECK_STATUS(_opaque->internal->use_cnnl(
(cnrtQueue_t)stream, (cnrtQueue_t)stream,
[&](cnnlHandle_t handle) { [&](cnnlHandle_t handle) {
cnnlBatchMatMulBCast_v2( CHECK_BANG(cnnlBatchMatMulBCast_v2(
handle, handle,
_opaque->op, _opaque->op,
_opaque->algo, _opaque->algo,
...@@ -146,9 +150,9 @@ infiniStatus_t Descriptor::calculate( ...@@ -146,9 +150,9 @@ infiniStatus_t Descriptor::calculate(
&beta, &beta,
_opaque->c, c, _opaque->c, c,
workspace, workspace,
workspace_size); workspace_size));
}); return INFINI_STATUS_SUCCESS;
cnrtQueueSync((cnrtQueue_t)stream); }));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -102,12 +102,8 @@ cnrtMemTransDir_t toBangMemcpyKind(infinirtMemcpyKind_t kind) { ...@@ -102,12 +102,8 @@ cnrtMemTransDir_t toBangMemcpyKind(infinirtMemcpyKind_t kind) {
return cnrtMemcpyHostToDev; return cnrtMemcpyHostToDev;
case INFINIRT_MEMCPY_D2H: case INFINIRT_MEMCPY_D2H:
return cnrtMemcpyDevToHost; return cnrtMemcpyDevToHost;
// Note: Bang has two types of D2D types,
// 1. cnrtMemcpyDevToDev: which is copy in a single device, and
// 2. cnrtMemcpyPeerToPeer: which is from a device to another.
// Here, cnrtMemcpyNoDirection is placed.
case INFINIRT_MEMCPY_D2D: case INFINIRT_MEMCPY_D2D:
return cnrtMemcpyNoDirection; return cnrtMemcpyDevToDev;
case INFINIRT_MEMCPY_H2H: case INFINIRT_MEMCPY_H2H:
return cnrtMemcpyHostToHost; return cnrtMemcpyHostToHost;
default: default:
...@@ -125,11 +121,15 @@ infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemc ...@@ -125,11 +121,15 @@ infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemc
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Async malloc is not supported on this backend, so the blocking cnrtMalloc
// is used instead.
//   p_ptr:  receives the device pointer written by cnrtMalloc.
//   size:   number of bytes requested.
//   stream: accepted for interface compatibility; not used by the blocking call.
// Returns INFINI_STATUS_SUCCESS when cnrtMalloc succeeds; otherwise whatever
// CHECK_BANGRT returns for the failing call.
infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
    (void)stream; // intentionally unused: allocation is not stream-ordered here
    CHECK_BANGRT(cnrtMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
}
// Async free is not supported on this backend, so the blocking cnrtFree is
// used instead.
//   ptr:    device pointer to release.
//   stream: stream the caller ordered the free against.
// Returns INFINI_STATUS_SUCCESS on success; otherwise whatever CHECK_BANGRT
// returns for the failing call.
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
    // A stream-ordered free must not release memory that work already queued
    // on `stream` may still be using, so wait for that work first.
    // NOTE(review): confirm whether cnrtFree already synchronises; if it does,
    // this wait is redundant but harmless.
    if (stream != nullptr) {
        CHECK_BANGRT(cnrtQueueSync((cnrtQueue_t)stream));
    }
    CHECK_BANGRT(cnrtFree(ptr));
    return INFINI_STATUS_SUCCESS;
}
} // namespace infinirt::bang } // namespace infinirt::bang
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment