Commit 1c551349 authored by Zimin Li's avatar Zimin Li
Browse files

Add CHECK_BANG to cnnl functions in matmul_bang.cc and changed the return type...

Add CHECK_BANG to cnnl functions in matmul_bang.cc and changed the return type of setMatrixTensorEx()
parent 7f2509b8
......@@ -20,7 +20,7 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const {
auto handle = cnnl_handles.pop();
if (!handle) {
cnnlCreate(&(*handle));
CHECK_BANG(cnnlCreate(&(*handle)));
}
CHECK_BANG(cnnlSetQueue(*handle, queue));
CHECK_STATUS(f(*handle));
......
......@@ -23,7 +23,7 @@ struct Descriptor::Opaque {
}
};
static void setMatrixTensorEx(
static infiniStatus_t setMatrixTensorEx(
cnnlTensorDescriptor_t desc,
const BlasMatrix &matrix, infiniDtype_t dtype,
bool trans = false) {
......@@ -39,20 +39,21 @@ static void setMatrixTensorEx(
case 3: {
std::vector<int> dim_size = {batch, rows, cols};
std::vector<int> dim_stride = {stride, row_stride, col_stride};
cnnlSetTensorDescriptorEx(
CHECK_BANG(cnnlSetTensorDescriptorEx(
desc, CNNL_LAYOUT_ARRAY,
device::bang::getCnnlDtype(dtype), dim_size.size(),
dim_size.data(), dim_stride.data());
dim_size.data(), dim_stride.data()));
} break;
case 2: {
std::vector<int> dim_size = {rows, cols};
std::vector<int> dim_stride = {row_stride, col_stride};
cnnlSetTensorDescriptorEx(
CHECK_BANG(cnnlSetTensorDescriptorEx(
desc, CNNL_LAYOUT_ARRAY,
device::bang::getCnnlDtype(dtype), dim_size.size(),
dim_size.data(), dim_stride.data());
dim_size.data(), dim_stride.data()));
} break;
}
return INFINI_STATUS_SUCCESS;
}
Descriptor::~Descriptor() {
......@@ -79,9 +80,9 @@ infiniStatus_t Descriptor::create(
}
cnnlTensorDescriptor_t a, b, c;
cnnlCreateTensorDescriptor(&a);
cnnlCreateTensorDescriptor(&b);
cnnlCreateTensorDescriptor(&c);
CHECK_BANG(cnnlCreateTensorDescriptor(&a));
CHECK_BANG(cnnlCreateTensorDescriptor(&b));
CHECK_BANG(cnnlCreateTensorDescriptor(&c));
setMatrixTensorEx(a, info.a_matrix, a_desc->dtype());
setMatrixTensorEx(b, info.b_matrix, b_desc->dtype());
......@@ -90,18 +91,20 @@ infiniStatus_t Descriptor::create(
cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
cnnlMatMulDescCreate(&op);
cnnlMatMulAlgoCreate(&algo);
cnnlCreateMatMulHeuristicResult(&algoResult);
CHECK_BANG(cnnlMatMulDescCreate(&op));
CHECK_BANG(cnnlMatMulAlgoCreate(&algo));
CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
int32_t use_stride = true;
cnnlSetMatMulDescAttr(
CHECK_BANG(cnnlSetMatMulDescAttr(
op,
CNNL_MATMUL_USE_STRIDE,
&use_stride,
sizeof(int32_t));
sizeof(int32_t)));
int count = 0;
CHECK_STATUS(handle->internal()->useCnnl((cnrtQueue_t) nullptr,
CHECK_STATUS(
handle->internal()->useCnnl(
(cnrtQueue_t) nullptr,
[&](cnnlHandle_t _handle) {
CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment