Commit 7a095dd8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove unnecessary code.

parent 1871d141
......@@ -470,17 +470,6 @@ struct onnx_parser
transb = parse_value(attributes.at("transB")).at<bool>();
}
// beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results
std::size_t num_squeeze = args[0]->get_shape().lens().size();
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
args[0] = prog.add_instruction(op::squeeze{vec_axises}, args[0]);
args[1] = prog.add_instruction(op::squeeze{vec_axises}, args[1]);
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
......@@ -489,13 +478,6 @@ struct onnx_parser
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
l3 = prog.add_instruction(op::unsqueeze{vec_axises}, l3);
}
auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
......@@ -510,12 +492,6 @@ struct onnx_parser
}
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
dot_res = prog.add_instruction(op::unsqueeze{vec_axises}, dot_res);
}
return dot_res;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment