Unverified Commit d0e8bb1a authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Tf multi output support (#760)



* fix relu6

* add more transposes

* add multi output

* formatting

* add tests

* formatting

* fix tests

* change to_nchw for outputs

* add python api

* fix cppcheck

* remove variable

* fix lambda

* add multi_output test

* add more tests and merge

* fix help message
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent ebf8bd20
......@@ -43,6 +43,7 @@ struct loader
std::string output_type;
std::string output;
std::vector<std::string> param_dims;
std::vector<std::string> output_names;
void parse(argument_parser& ap)
{
......@@ -65,6 +66,12 @@ struct loader
ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
ap.append(),
ap.nargs(2));
ap(output_names,
{"--output-names"},
ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
ap.append(),
ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(output_type,
{"--graphviz", "-g"},
......@@ -106,12 +113,24 @@ struct loader
return map_input_dims;
}
// Convert the raw --output-names CLI strings into the list of graph node
// names whose outputs the parsed program should return.
// Each entry is run through value_parser so it receives the same
// normalization as every other CLI string option.
static auto parse_output_names(const std::vector<std::string>& output_names_info)
{
    std::vector<std::string> output_node_names;
    output_node_names.reserve(output_names_info.size());
    // Empty capture list: the lambda uses no enclosing-scope state
    // (a default [&] capture here is flagged by cppcheck). Take the
    // element by const reference to avoid copying each std::string.
    std::transform(output_names_info.begin(),
                   output_names_info.end(),
                   std::back_inserter(output_node_names),
                   [](const auto& name) { return value_parser<std::string>::apply(name); });
    return output_node_names;
}
program load()
{
program p;
if(model.empty())
{
auto map_input_dims = parse_param_dims(param_dims);
auto map_input_dims = parse_param_dims(param_dims);
auto output_node_names = parse_output_names(output_names);
if(file_type.empty())
{
if(ends_with(file, ".onnx"))
......@@ -135,7 +154,7 @@ struct loader
}
else if(file_type == "tf")
{
p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims});
p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims, output_node_names});
}
else if(file_type == "json")
{
......
......@@ -14,6 +14,7 @@ struct tf_options
unsigned int batch_size = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
std::vector<std::string> output_node_names = {};
};
/// Create a program from a tf pb file (default is nhwc format)
......
......@@ -305,13 +305,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name);
m.def("parse_tf",
[](const std::string& filename, bool is_nhwc, unsigned int batch_size) {
return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size});
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1);
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx",
[](const std::string& filename,
......
......@@ -54,6 +54,7 @@ struct tf_parser
const tf_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes;
std::vector<std::string> output_node_names;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
module* mm = prog.get_main_module();
......
......@@ -17,9 +17,10 @@ program parse_tf(const std::string& name, const tf_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf::tf_parser parser;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
parser.output_node_names = options.output_node_names;
#ifndef NDEBUG
// Log the program when it can't be parsed
......@@ -35,7 +36,6 @@ program parse_tf(const std::string& name, const tf_options& options)
#else
parser.parse_from(input);
#endif
parser.to_nchw(std::prev(parser.mm->end()));
return std::move(parser.prog);
}
......
......@@ -286,9 +286,30 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{
this->parse_node(p.first);
}
// Needs to add a ret instruction at the end of
// the program
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
{
// Needs to add a ret instruction at the end of
// the program
if(output_node_names.empty())
{
mm->add_return({to_nchw(last_ins)});
}
else
{
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
}
}
}
void tf_parser::parse_node(const std::string& name)
......
......@@ -355,6 +355,16 @@ def mul_test(g1):
tf.multiply(g1_input, g2_input, name='mul1')
@tf_test
def multi_output_test(g1):
    """Build a graph with two independent outputs (relu and tanh of the
    same placeholder) to exercise the parser's --output-names support.

    The op names are passed via the ``name=`` keyword for consistency with
    the sibling tests (e.g. ``mul_test`` uses ``name='mul1'``); passing them
    positionally relies on ``name`` being the second parameter and reads as
    if a second tensor were supplied.
    """
    with g1.as_default():
        g1_input = tf.compat.v1.placeholder(tf.float32,
                                            shape=(1, 3, 16, 16),
                                            name='0')
        tf.nn.relu(g1_input, name='relu')
        tf.tanh(g1_input, name='tanh')
@tf_test
def noop_test(g1):
with g1.as_default():
......@@ -645,6 +655,7 @@ if __name__ == '__main__':
mean_test()
mean_test_nhwc()
mul_test()
multi_output_test()
noop_test()
onehot_test()
pack_test()
......
:
0 Placeholder*
dtype0*
shape:

reluRelu0*
T0

tanhTanh0*
T0"&
\ No newline at end of file
......@@ -19,9 +19,11 @@
migraphx::program
parse_tf(const std::string& name,
bool is_nhwc,
const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {})
const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
const std::vector<std::string>& output_node_names = {})
{
return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1, dim_params});
return migraphx::parse_tf(name,
migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names});
}
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
......@@ -33,6 +35,14 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
// remove the last return instruction
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
if(last_ins->name() == "@return")
{
mm->remove_instruction(last_ins);
}
return prog;
}
......@@ -86,7 +96,8 @@ TEST_CASE(argmax_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({l1});
auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
EXPECT(p == prog);
......@@ -100,7 +111,8 @@ TEST_CASE(argmin_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({l1});
auto prog = parse_tf("argmin_test.pb", false);
EXPECT(p == prog);
......@@ -502,6 +514,22 @@ TEST_CASE(mul_test)
EXPECT(p == prog);
}
// Verify multi-output parsing: the reference program returns both the relu
// and tanh results of the same input, matching what parse_tf produces when
// given the two node names via output_names.
TEST_CASE(multi_output_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x =
        mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto relu_ins = mm->add_instruction(migraphx::make_op("relu"), x);
    auto tanh_ins = mm->add_instruction(migraphx::make_op("tanh"), x);
    mm->add_return({relu_ins, tanh_ins});

    // Asking for a node name that does not exist in the graph must throw.
    EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); }));

    auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});
    EXPECT(p == prog);
}
TEST_CASE(onehot_test)
{
migraphx::program p;
......@@ -755,8 +783,8 @@ TEST_CASE(split_test)
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto prog = parse_tf("split_test.pb", false);
EXPECT(p == prog);
......@@ -770,8 +798,8 @@ TEST_CASE(split_test_one_output)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
mm->add_literal(1); // num_splits
mm->add_literal(1); // split axis
mm->add_instruction(migraphx::make_op("identity"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({l1});
auto prog = parse_tf("split_test_one_output.pb", false);
EXPECT(p == prog);
......@@ -797,8 +825,8 @@ TEST_CASE(split_test_vector_as_input)
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto prog = parse_tf("split_test_vector_as_input.pb", false);
EXPECT(p == prog);
......@@ -884,7 +912,8 @@ TEST_CASE(stridedslice_masks_test)
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
mm->add_return({l3});
auto prog = parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
......@@ -897,7 +926,8 @@ TEST_CASE(sub_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::make_op("sub"), l0, l1);
auto l2 = mm->add_instruction(migraphx::make_op("sub"), l0, l1);
mm->add_return({l2});
auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog);
......@@ -908,10 +938,10 @@ TEST_CASE(tanh_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::make_op("sub"), l0, l1);
auto prog = parse_tf("sub_test.pb", false);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::make_op("tanh"), l0);
mm->add_return({l1});
auto prog = parse_tf("tanh_test.pb", false);
EXPECT(p == prog);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment