Commit 70d2d53a authored by Zimin Li's avatar Zimin Li
Browse files

issue/89 merge _internal.h content to common_bang.h, add cnrtQueueSync in...

issue/89 merge _internal.h content to common_bang.h, add cnrtQueueSync in matmul bang implementation, and fix misc. issues
parent 1c551349
#ifndef __INFINIOP_BANG_INTERNAL_H__
#define __INFINIOP_BANG_INTERNAL_H__
#include "../../../utils.h"
#include "../pool.h"
#include "bang_handle.h"
#include "cnnl.h"
#include "cnrt.h"
#include <functional>
namespace device::bang {
// Device-side internals for the BANG Handle: owns a pool of CNNL library
// handles so callers can borrow one for a given cnrtQueue_t instead of
// creating/destroying a cnnlHandle_t per call.
class Handle::Internal {
// Pooled cnnlHandle_t instances, reused across useCnnl() invocations.
Pool<cnnlHandle_t> cnnl_handles;
// Callback signature: receives a borrowed resource of type T and
// reports success/failure via infiniStatus_t.
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
// Invokes `f` with a cnnlHandle_t taken from the pool, for work that is
// to be issued on `queue`; returns the status produced by `f` or by the
// handle setup. NOTE(review): the definition is out-of-line and not
// visible here — confirm in the .cc that the handle is bound to `queue`
// (cnnlSetQueue) before `f` runs and returned to the pool afterwards.
infiniStatus_t useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const;
};
cnnlDataType_t getCnnlDtype(infiniDtype_t dt);
} // namespace device::bang
#endif // __INFINIOP_BANG_INTERNAL_H__
#include "../../tensor.h"
#include "../pool.h"
#include "_internal.h"
#include "cnnl.h"
#include "common_bang.h"
#include "infiniop/tensor_descriptor.h"
......@@ -51,9 +50,8 @@ cnnlDataType_t getCnnlDtype(infiniDtype_t dt) {
}
}
// set cnnl tensor descriptor without strides
inline infiniStatus_t setCnnlTensor(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout) {
infiniStatus_t setCnnlTensor(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout) {
std::vector<int> dims(layout->ndim());
for (size_t i = 0; i < layout->ndim(); i++) {
dims[i] = static_cast<int>(layout->shape()[i]);
......@@ -64,9 +62,8 @@ inline infiniStatus_t setCnnlTensor(cnnlTensorDescriptor_t desc,
return INFINI_STATUS_SUCCESS;
}
// set cnnl tensor descriptor with strides
inline infiniStatus_t setCnnlTensorEx(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout) {
infiniStatus_t setCnnlTensorEx(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout) {
std::vector<int> dim_size(layout->ndim()), dim_stride(layout->ndim());
for (size_t i = 0; i < layout->ndim(); i++) {
dim_size[i] = static_cast<int>(layout->shape()[i]);
......
......@@ -2,18 +2,36 @@
#define __COMMON_BANG_H__
#include "../../../utils.h"
#include "../pool.h"
#include "bang_handle.h"
#include "cnnl.h"
#include "cnrt.h"
#include <functional>
// total NRAM capacity is 1024 * 768 bytes; NRAM_MAX_SIZE deliberately caps usage below that (1024 * 256)
#define NRAM_MAX_SIZE (1024 * 256)
#define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)
#define GDRAM_MAX_SIZE (1024 * 1024 * 1024)
namespace device::bang {
#ifdef __cplusplus
extern "C" {
#endif
#define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)
#ifdef __cplusplus
// Device-side internals for the BANG Handle: owns a pool of CNNL library
// handles so callers can borrow one for a given cnrtQueue_t instead of
// creating/destroying a cnnlHandle_t per call.
class Handle::Internal {
// Pooled cnnlHandle_t instances, reused across useCnnl() invocations.
Pool<cnnlHandle_t> cnnl_handles;
// Callback signature: receives a borrowed resource of type T and
// reports success/failure via infiniStatus_t.
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
// Invokes `f` with a cnnlHandle_t taken from the pool, for work that is
// to be issued on `queue`; returns the status produced by `f` or by the
// handle setup. NOTE(review): the definition is out-of-line and not
// visible here — confirm in the .cc that the handle is bound to `queue`
// (cnnlSetQueue) before `f` runs and returned to the pool afterwards.
infiniStatus_t useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const;
};
#endif
cnnlDataType_t getCnnlDtype(infiniDtype_t dt);
// set cnnl tensor descriptor without strides
infiniStatus_t setCnnlTensor(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout);
// set cnnl tensor descriptor with strides
infiniStatus_t setCnnlTensorEx(cnnlTensorDescriptor_t desc,
const InfiniopTensorDescriptor *layout);
} // namespace device::bang
#endif // __COMMON_BANG_H__
#include "matmul_bang.h"
#include "../../../devices/bang/_internal.h"
#include "../../../devices/bang/bang_handle.h"
#include "../../../devices/bang/common_bang.h"
#include <cnnl_extra.h>
......@@ -84,9 +82,9 @@ infiniStatus_t Descriptor::create(
CHECK_BANG(cnnlCreateTensorDescriptor(&b));
CHECK_BANG(cnnlCreateTensorDescriptor(&c));
setMatrixTensorEx(a, info.a_matrix, a_desc->dtype());
setMatrixTensorEx(b, info.b_matrix, b_desc->dtype());
setMatrixTensorEx(c, info.c_matrix, c_desc->dtype());
CHECK_STATUS(setMatrixTensorEx(a, info.a_matrix, a_desc->dtype()));
CHECK_STATUS(setMatrixTensorEx(b, info.b_matrix, b_desc->dtype()));
CHECK_STATUS(setMatrixTensorEx(c, info.c_matrix, c_desc->dtype()));
cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo;
......@@ -154,6 +152,7 @@ infiniStatus_t Descriptor::calculate(
workspace_size));
return INFINI_STATUS_SUCCESS;
}));
cnrtQueueSync((cnrtQueue_t)stream);
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment