Unverified Commit d0e8bb1a authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Tf multi output support (#760)



* fix relu6

* add more transposes

* add multi output

* formatting

* add tests

* formatting

* fix tests

* change to_nchw for outputs

* add python api

* fix cppcheck

* remove variable

* fix lambda

* add multi_output test

* add more tests and merge

* fix help message
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent ebf8bd20
...@@ -43,6 +43,7 @@ struct loader ...@@ -43,6 +43,7 @@ struct loader
std::string output_type; std::string output_type;
std::string output; std::string output;
std::vector<std::string> param_dims; std::vector<std::string> param_dims;
std::vector<std::string> output_names;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -65,6 +66,12 @@ struct loader ...@@ -65,6 +66,12 @@ struct loader
ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"), ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
ap.append(), ap.append(),
ap.nargs(2)); ap.nargs(2));
ap(output_names,
{"--output-names"},
ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
ap.append(),
ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(output_type, ap(output_type,
{"--graphviz", "-g"}, {"--graphviz", "-g"},
...@@ -106,12 +113,24 @@ struct loader ...@@ -106,12 +113,24 @@ struct loader
return map_input_dims; return map_input_dims;
} }
static auto parse_output_names(const std::vector<std::string>& output_names_info)
{
std::vector<std::string> output_node_names;
std::transform(output_names_info.begin(),
output_names_info.end(),
std::back_inserter(output_node_names),
[&](auto x) { return value_parser<std::string>::apply(x); });
return output_node_names;
}
program load() program load()
{ {
program p; program p;
if(model.empty()) if(model.empty())
{ {
auto map_input_dims = parse_param_dims(param_dims); auto map_input_dims = parse_param_dims(param_dims);
auto output_node_names = parse_output_names(output_names);
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) if(ends_with(file, ".onnx"))
...@@ -135,7 +154,7 @@ struct loader ...@@ -135,7 +154,7 @@ struct loader
} }
else if(file_type == "tf") else if(file_type == "tf")
{ {
p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims}); p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims, output_node_names});
} }
else if(file_type == "json") else if(file_type == "json")
{ {
......
...@@ -14,6 +14,7 @@ struct tf_options ...@@ -14,6 +14,7 @@ struct tf_options
unsigned int batch_size = 1; unsigned int batch_size = 1;
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
std::vector<std::string> output_node_names = {};
}; };
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
......
...@@ -305,13 +305,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -305,13 +305,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name); .def("name", &migraphx::operation::name);
m.def("parse_tf", m.def("parse_tf",
[](const std::string& filename, bool is_nhwc, unsigned int batch_size) { [](const std::string& filename,
return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size}); bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
}, },
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true, py::arg("is_nhwc") = true,
py::arg("batch_size") = 1); py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx", m.def("parse_onnx",
[](const std::string& filename, [](const std::string& filename,
......
...@@ -54,6 +54,7 @@ struct tf_parser ...@@ -54,6 +54,7 @@ struct tf_parser
const tf_parser&, const node_info&, std::vector<instruction_ref>)>; const tf_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes; std::vector<tensorflow::NodeDef> input_nodes;
std::vector<std::string> output_node_names;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
module* mm = prog.get_main_module(); module* mm = prog.get_main_module();
......
...@@ -20,6 +20,7 @@ program parse_tf(const std::string& name, const tf_options& options) ...@@ -20,6 +20,7 @@ program parse_tf(const std::string& name, const tf_options& options)
parser.is_nhwc = options.is_nhwc; parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size; parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims; parser.map_input_dims = options.map_input_dims;
parser.output_node_names = options.output_node_names;
#ifndef NDEBUG #ifndef NDEBUG
// Log the program when it can't be parsed // Log the program when it can't be parsed
...@@ -35,7 +36,6 @@ program parse_tf(const std::string& name, const tf_options& options) ...@@ -35,7 +36,6 @@ program parse_tf(const std::string& name, const tf_options& options)
#else #else
parser.parse_from(input); parser.parse_from(input);
#endif #endif
parser.to_nchw(std::prev(parser.mm->end()));
return std::move(parser.prog); return std::move(parser.prog);
} }
......
...@@ -286,9 +286,30 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph) ...@@ -286,9 +286,30 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{ {
this->parse_node(p.first); this->parse_node(p.first);
} }
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
{
// Needs to add a ret instruction at the end of // Needs to add a ret instruction at the end of
// the program // the program
if(output_node_names.empty())
{
mm->add_return({to_nchw(last_ins)});
}
else
{
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
}
}
} }
void tf_parser::parse_node(const std::string& name) void tf_parser::parse_node(const std::string& name)
......
...@@ -355,6 +355,16 @@ def mul_test(g1): ...@@ -355,6 +355,16 @@ def mul_test(g1):
tf.multiply(g1_input, g2_input, name='mul1') tf.multiply(g1_input, g2_input, name='mul1')
@tf_test
def multi_output_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32,
shape=(1, 3, 16, 16),
name='0')
tf.nn.relu(g1_input, 'relu')
tf.tanh(g1_input, 'tanh')
@tf_test @tf_test
def noop_test(g1): def noop_test(g1):
with g1.as_default(): with g1.as_default():
...@@ -645,6 +655,7 @@ if __name__ == '__main__': ...@@ -645,6 +655,7 @@ if __name__ == '__main__':
mean_test() mean_test()
mean_test_nhwc() mean_test_nhwc()
mul_test() mul_test()
multi_output_test()
noop_test() noop_test()
onehot_test() onehot_test()
pack_test() pack_test()
......
:
0 Placeholder*
dtype0*
shape:

reluRelu0*
T0

tanhTanh0*
T0"&
\ No newline at end of file
...@@ -19,9 +19,11 @@ ...@@ -19,9 +19,11 @@
migraphx::program migraphx::program
parse_tf(const std::string& name, parse_tf(const std::string& name,
bool is_nhwc, bool is_nhwc,
const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {}) const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {},
const std::vector<std::string>& output_node_names = {})
{ {
return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1, dim_params}); return migraphx::parse_tf(name,
migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names});
} }
migraphx::program optimize_tf(const std::string& name, bool is_nhwc) migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
...@@ -33,6 +35,14 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc) ...@@ -33,6 +35,14 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{migraphx::simplify_reshapes{}, {migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}}); migraphx::eliminate_identity{}});
// remove the last return instruction
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
if(last_ins->name() == "@return")
{
mm->remove_instruction(last_ins);
}
return prog; return prog;
} }
...@@ -86,7 +96,8 @@ TEST_CASE(argmax_test) ...@@ -86,7 +96,8 @@ TEST_CASE(argmax_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0); auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({l1});
auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}}); auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
EXPECT(p == prog); EXPECT(p == prog);
...@@ -100,7 +111,8 @@ TEST_CASE(argmin_test) ...@@ -100,7 +111,8 @@ TEST_CASE(argmin_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0); auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({l1});
auto prog = parse_tf("argmin_test.pb", false); auto prog = parse_tf("argmin_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -502,6 +514,22 @@ TEST_CASE(mul_test) ...@@ -502,6 +514,22 @@ TEST_CASE(mul_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(multi_output_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::make_op("relu"), l0);
auto l2 = mm->add_instruction(migraphx::make_op("tanh"), l0);
mm->add_return({l1, l2});
EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); }));
auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"});
EXPECT(p == prog);
}
TEST_CASE(onehot_test) TEST_CASE(onehot_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -755,8 +783,8 @@ TEST_CASE(split_test) ...@@ -755,8 +783,8 @@ TEST_CASE(split_test)
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto prog = parse_tf("split_test.pb", false); auto prog = parse_tf("split_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -770,8 +798,8 @@ TEST_CASE(split_test_one_output) ...@@ -770,8 +798,8 @@ TEST_CASE(split_test_one_output)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
mm->add_literal(1); // num_splits mm->add_literal(1); // num_splits
mm->add_literal(1); // split axis mm->add_literal(1); // split axis
mm->add_instruction(migraphx::make_op("identity"), l0); auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({l1});
auto prog = parse_tf("split_test_one_output.pb", false); auto prog = parse_tf("split_test_one_output.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -797,8 +825,8 @@ TEST_CASE(split_test_vector_as_input) ...@@ -797,8 +825,8 @@ TEST_CASE(split_test_vector_as_input)
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto prog = parse_tf("split_test_vector_as_input.pb", false); auto prog = parse_tf("split_test_vector_as_input.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -884,7 +912,8 @@ TEST_CASE(stridedslice_masks_test) ...@@ -884,7 +912,8 @@ TEST_CASE(stridedslice_masks_test)
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0); auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2); auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), l2);
mm->add_return({l3});
auto prog = parse_tf("stridedslice_masks_test.pb", true); auto prog = parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -897,7 +926,8 @@ TEST_CASE(sub_test) ...@@ -897,7 +926,8 @@ TEST_CASE(sub_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::make_op("sub"), l0, l1); auto l2 = mm->add_instruction(migraphx::make_op("sub"), l0, l1);
mm->add_return({l2});
auto prog = parse_tf("sub_test.pb", false); auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -908,10 +938,10 @@ TEST_CASE(tanh_test) ...@@ -908,10 +938,10 @@ TEST_CASE(tanh_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l1 = mm->add_instruction(migraphx::make_op("tanh"), l0);
mm->add_instruction(migraphx::make_op("sub"), l0, l1); mm->add_return({l1});
auto prog = parse_tf("sub_test.pb", false); auto prog = parse_tf("tanh_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment