Commit 95b2a528 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent a392d84c
...@@ -206,6 +206,16 @@ struct onnx_parser ...@@ -206,6 +206,16 @@ struct onnx_parser
return out_lens; return out_lens;
} }
instruction_ref make_contiguous(instruction_ref ins)
{
if (ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
template <class T> template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{ {
...@@ -437,11 +447,7 @@ struct onnx_parser ...@@ -437,11 +447,7 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
if(!args[0]->get_shape().standard()) args[0] = make_contiguous(args[0]);
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
...@@ -490,15 +496,8 @@ struct onnx_parser ...@@ -490,15 +496,8 @@ struct onnx_parser
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
} }
args[0] = make_contiguous(args[0]);
if(!args[0]->get_shape().standard()) args[1] = make_contiguous(args[1]);
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
if(!args[1]->get_shape().standard())
{
args[1] = prog.add_instruction(op::contiguous{}, args[1]);
}
op::gather op{axis}; op::gather op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment