Commit b9c81e2d authored by Khalique's avatar Khalique
Browse files

fix logic for conv weights

parent 8a4f5778
...@@ -300,16 +300,22 @@ struct tf_parser ...@@ -300,16 +300,22 @@ struct tf_parser
op.dilation[0] = dilation[2]; op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3]; op.dilation[1] = dilation[3];
} }
auto l0 = args[1]; auto weights = args[1];
// check if weights are from a constant // check if weights are from a constant
if(l0->inputs().at(0)->name() == "@literal" and is_nhwc)
if(weights->name() != "@param")
{
if(is_nhwc)
{ {
l0 = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]); weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
} }
else if(l0->name() != "@param")
MIGRAPHX_THROW("cannot infer data format for weights");
return prog.add_instruction(op, {args[0], l0}); return prog.add_instruction(op, {args[0], weights});
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment