Commit 1c551349 authored by Zimin Li's avatar Zimin Li
Browse files

Add CHECK_BANG to cnnl functions in matmul_bang.cc and change the return type...

Add CHECK_BANG to cnnl functions in matmul_bang.cc and change the return type of setMatrixTensorEx()
parent 7f2509b8
...@@ -20,7 +20,7 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -20,7 +20,7 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const { infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const {
auto handle = cnnl_handles.pop(); auto handle = cnnl_handles.pop();
if (!handle) { if (!handle) {
cnnlCreate(&(*handle)); CHECK_BANG(cnnlCreate(&(*handle)));
} }
CHECK_BANG(cnnlSetQueue(*handle, queue)); CHECK_BANG(cnnlSetQueue(*handle, queue));
CHECK_STATUS(f(*handle)); CHECK_STATUS(f(*handle));
......
...@@ -23,7 +23,7 @@ struct Descriptor::Opaque { ...@@ -23,7 +23,7 @@ struct Descriptor::Opaque {
} }
}; };
static void setMatrixTensorEx( static infiniStatus_t setMatrixTensorEx(
cnnlTensorDescriptor_t desc, cnnlTensorDescriptor_t desc,
const BlasMatrix &matrix, infiniDtype_t dtype, const BlasMatrix &matrix, infiniDtype_t dtype,
bool trans = false) { bool trans = false) {
...@@ -39,20 +39,21 @@ static void setMatrixTensorEx( ...@@ -39,20 +39,21 @@ static void setMatrixTensorEx(
case 3: { case 3: {
std::vector<int> dim_size = {batch, rows, cols}; std::vector<int> dim_size = {batch, rows, cols};
std::vector<int> dim_stride = {stride, row_stride, col_stride}; std::vector<int> dim_stride = {stride, row_stride, col_stride};
cnnlSetTensorDescriptorEx( CHECK_BANG(cnnlSetTensorDescriptorEx(
desc, CNNL_LAYOUT_ARRAY, desc, CNNL_LAYOUT_ARRAY,
device::bang::getCnnlDtype(dtype), dim_size.size(), device::bang::getCnnlDtype(dtype), dim_size.size(),
dim_size.data(), dim_stride.data()); dim_size.data(), dim_stride.data()));
} break; } break;
case 2: { case 2: {
std::vector<int> dim_size = {rows, cols}; std::vector<int> dim_size = {rows, cols};
std::vector<int> dim_stride = {row_stride, col_stride}; std::vector<int> dim_stride = {row_stride, col_stride};
cnnlSetTensorDescriptorEx( CHECK_BANG(cnnlSetTensorDescriptorEx(
desc, CNNL_LAYOUT_ARRAY, desc, CNNL_LAYOUT_ARRAY,
device::bang::getCnnlDtype(dtype), dim_size.size(), device::bang::getCnnlDtype(dtype), dim_size.size(),
dim_size.data(), dim_stride.data()); dim_size.data(), dim_stride.data()));
} break; } break;
} }
return INFINI_STATUS_SUCCESS;
} }
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -79,9 +80,9 @@ infiniStatus_t Descriptor::create( ...@@ -79,9 +80,9 @@ infiniStatus_t Descriptor::create(
} }
cnnlTensorDescriptor_t a, b, c; cnnlTensorDescriptor_t a, b, c;
cnnlCreateTensorDescriptor(&a); CHECK_BANG(cnnlCreateTensorDescriptor(&a));
cnnlCreateTensorDescriptor(&b); CHECK_BANG(cnnlCreateTensorDescriptor(&b));
cnnlCreateTensorDescriptor(&c); CHECK_BANG(cnnlCreateTensorDescriptor(&c));
setMatrixTensorEx(a, info.a_matrix, a_desc->dtype()); setMatrixTensorEx(a, info.a_matrix, a_desc->dtype());
setMatrixTensorEx(b, info.b_matrix, b_desc->dtype()); setMatrixTensorEx(b, info.b_matrix, b_desc->dtype());
...@@ -90,18 +91,20 @@ infiniStatus_t Descriptor::create( ...@@ -90,18 +91,20 @@ infiniStatus_t Descriptor::create(
cnnlMatMulDescriptor_t op; cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo; cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult; cnnlMatMulHeuristicResult_t algoResult;
cnnlMatMulDescCreate(&op); CHECK_BANG(cnnlMatMulDescCreate(&op));
cnnlMatMulAlgoCreate(&algo); CHECK_BANG(cnnlMatMulAlgoCreate(&algo));
cnnlCreateMatMulHeuristicResult(&algoResult); CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
int32_t use_stride = true; int32_t use_stride = true;
cnnlSetMatMulDescAttr( CHECK_BANG(cnnlSetMatMulDescAttr(
op, op,
CNNL_MATMUL_USE_STRIDE, CNNL_MATMUL_USE_STRIDE,
&use_stride, &use_stride,
sizeof(int32_t)); sizeof(int32_t)));
int count = 0; int count = 0;
CHECK_STATUS(handle->internal()->useCnnl((cnrtQueue_t) nullptr, CHECK_STATUS(
handle->internal()->useCnnl(
(cnrtQueue_t) nullptr,
[&](cnnlHandle_t _handle) { [&](cnnlHandle_t _handle) {
CHECK_BANG( CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic( cnnlGetBatchMatMulAlgoHeuristic(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment