Commit b658a9f5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove unnecessary changes

parent ad3c6d0d
...@@ -30,7 +30,7 @@ struct multibroadcast ...@@ -30,7 +30,7 @@ struct multibroadcast
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
......
...@@ -214,15 +214,6 @@ struct onnx_parser ...@@ -214,15 +214,6 @@ struct onnx_parser
auto s0 = arg0->get_shape().lens(); auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens(); auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1); auto out_lens = compute_broadcasted_lens(s0, s1);
if(!arg0->get_shape().standard())
{
arg0 = prog.add_instruction(op::contiguous{}, arg0);
}
if(!arg1->get_shape().standard())
{
arg1 = prog.add_instruction(op::contiguous{}, arg1);
}
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0); auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1); auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1); return prog.add_instruction(x, l0, l1);
...@@ -643,10 +634,6 @@ struct onnx_parser ...@@ -643,10 +634,6 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens(); auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
if(!args[2]->get_shape().standard())
{
args[2] = prog.add_instruction(op::contiguous{}, args[2]);
}
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]); l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3); return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
...@@ -696,18 +683,10 @@ struct onnx_parser ...@@ -696,18 +683,10 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens) if(l0_lens != l0_broadcasted_lens)
{ {
if(!l0->get_shape().standard())
{
l0 = prog.add_instruction(op::contiguous{}, l0);
}
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0); bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
} }
if(l1_lens != l1_broadcasted_lens) if(l1_lens != l1_broadcasted_lens)
{ {
if(!l1->get_shape().standard())
{
l1 = prog.add_instruction(op::contiguous{}, l1);
}
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1); bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
} }
} }
...@@ -1039,11 +1018,6 @@ struct onnx_parser ...@@ -1039,11 +1018,6 @@ struct onnx_parser
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
if(!args[0]->get_shape().standard())
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]); return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment