Commit b658a9f5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove unnecessary changes

parent ad3c6d0d
......@@ -30,7 +30,7 @@ struct multibroadcast
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
......
......@@ -214,15 +214,6 @@ struct onnx_parser
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
if(!arg0->get_shape().standard())
{
arg0 = prog.add_instruction(op::contiguous{}, arg0);
}
if(!arg1->get_shape().standard())
{
arg1 = prog.add_instruction(op::contiguous{}, arg1);
}
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1);
......@@ -643,10 +634,6 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
if(!args[2]->get_shape().standard())
{
args[2] = prog.add_instruction(op::contiguous{}, args[2]);
}
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
......@@ -696,18 +683,10 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
if(!l0->get_shape().standard())
{
l0 = prog.add_instruction(op::contiguous{}, l0);
}
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if(l1_lens != l1_broadcasted_lens)
{
if(!l1->get_shape().standard())
{
l1 = prog.add_instruction(op::contiguous{}, l1);
}
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
}
......@@ -1039,11 +1018,6 @@ struct onnx_parser
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
if(!args[0]->get_shape().standard())
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment