Commit 7a095dd8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove unnecessary code.

parent 1871d141
...@@ -470,17 +470,6 @@ struct onnx_parser ...@@ -470,17 +470,6 @@ struct onnx_parser
transb = parse_value(attributes.at("transB")).at<bool>(); transb = parse_value(attributes.at("transB")).at<bool>();
} }
// beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results
std::size_t num_squeeze = args[0]->get_shape().lens().size();
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
args[0] = prog.add_instruction(op::squeeze{vec_axises}, args[0]);
args[1] = prog.add_instruction(op::squeeze{vec_axises}, args[1]);
}
std::vector<int64_t> perm = {1, 0}; std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
...@@ -489,13 +478,6 @@ struct onnx_parser ...@@ -489,13 +478,6 @@ struct onnx_parser
if(beta != 0.f) if(beta != 0.f)
{ {
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
l3 = prog.add_instruction(op::unsqueeze{vec_axises}, l3);
}
auto l4 = args[2]; auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B) if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3; return l3;
...@@ -510,12 +492,6 @@ struct onnx_parser ...@@ -510,12 +492,6 @@ struct onnx_parser
} }
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2); auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
dot_res = prog.add_instruction(op::unsqueeze{vec_axises}, dot_res);
}
return dot_res; return dot_res;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment