Commit 8bf2940a authored by Khalique's avatar Khalique
Browse files

continued progress, fixed issues with conv, pooling, added reshape

parent cbd244d1
...@@ -48,6 +48,8 @@ struct tf_parser ...@@ -48,6 +48,8 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
// add_mem_op("Reshape", &tf_parser::parse_reshape);
} }
template <class F> template <class F>
...@@ -156,7 +158,8 @@ struct tf_parser ...@@ -156,7 +158,8 @@ struct tf_parser
std::size_t axis_idx = attributes.at("N").i(); std::size_t axis_idx = attributes.at("N").i();
std::size_t axis = args[axis_idx]->eval().at<int64_t>(); std::size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis}; op::concat op{axis};
return prog.add_instruction(op, std::move(args)); // return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(op, std::vector<instruction_ref>(args.begin(), args.begin() + axis));
} }
instruction_ref parse_constant(const std::string&, instruction_ref parse_constant(const std::string&,
...@@ -232,9 +235,14 @@ struct tf_parser ...@@ -232,9 +235,14 @@ struct tf_parser
op.dilation[1] = dilation[3]; op.dilation[1] = dilation[3];
} }
} }
auto l0 = args[0];
auto l0 = prog.add_instruction(op::transpose{{2, 3, 0, 1}}, args[1]); if (l0->name() == "@param")
return prog.add_instruction(op, {args[0], l0}); {
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
}
auto l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {l0, l1});
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
...@@ -258,7 +266,7 @@ struct tf_parser ...@@ -258,7 +266,7 @@ struct tf_parser
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
std::vector<std::size_t> stride; std::vector<std::size_t> stride;
copy(attributes.at("stride").list().i(), std::back_inserter(stride)); copy(attributes.at("strides").list().i(), std::back_inserter(stride));
if(stride.size() != 4) if(stride.size() != 4)
{ {
MIGRAPHX_THROW("strides should have 4 values"); MIGRAPHX_THROW("strides should have 4 values");
...@@ -297,6 +305,17 @@ struct tf_parser ...@@ -297,6 +305,17 @@ struct tf_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref
parse_reshape(const std::string&, attribute_map, std::vector<instruction_ref> args)
{
op::reshape op;
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
literal s = args[1]->get_literal();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args[0]);
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
tensorflow::GraphDef graph; tensorflow::GraphDef graph;
...@@ -321,11 +340,6 @@ struct tf_parser ...@@ -321,11 +340,6 @@ struct tf_parser
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape()); std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
if(is_nhwc)
{
// nhwc to nchw
prog.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, instructions[name]);
}
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
...@@ -339,7 +353,7 @@ struct tf_parser ...@@ -339,7 +353,7 @@ struct tf_parser
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
std::cout << name << std::endl; // std::cout << name << std::endl;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment