Commit 932f3a28 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed the issue in the cpu implementation for the inner product

parent f6d77130
...@@ -441,15 +441,6 @@ struct cpu_gemm ...@@ -441,15 +441,6 @@ struct cpu_gemm
} }
// 2 input cases // 2 input cases
// all args are scalar
if(output_shape.scalar())
{
visit_all(result, args[0], args[1])(
[&](auto res, auto a, auto b) { res[0] = op.alpha * a[0] * b[0]; });
return result;
}
// first argument is 1-dim, pre-pend 1 at beginning // first argument is 1-dim, pre-pend 1 at beginning
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
...@@ -461,8 +452,11 @@ struct cpu_gemm ...@@ -461,8 +452,11 @@ struct cpu_gemm
is_a_prepended = true; is_a_prepended = true;
a_lens.insert(a_lens.begin(), 1); a_lens.insert(a_lens.begin(), 1);
out_lens.push_back(1); out_lens.push_back(1);
if (out_lens.size() > 1)
{
std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1)); std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1));
} }
}
bool is_b_appended = false; bool is_b_appended = false;
if(b_lens.size() == 1) if(b_lens.size() == 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment