Unverified Commit 197b3a46 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Infer outputs in tf (#764)



* fix relu6

* add more transposes

* add multi output

* formatting

* add tests

* formatting

* fix tests

* change to_nchw for outputs

* add python api

* fix cppcheck

* remove variable

* fix lambda

* add multi_output test

* add more tests and merge

* fix help message

* debugging work

* fix valid op string

* formatting

* manual merge

* mark function as const
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Shucai Xiao <shucai@gmail.com>
parent d43cf0a3
...@@ -7,13 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -7,13 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void unary_not(hipStream_t stream, void unary_not(hipStream_t stream, const argument& result, const argument& arg)
const argument& result,
const argument& arg
)
{ {
nary(stream, result, arg)( nary(stream, result, arg)([](auto x) __device__ { return not x; });
[](auto x) __device__ { return not x; });
} }
} // namespace device } // namespace device
......
...@@ -10,9 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,9 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void unary_not(hipStream_t stream, void unary_not(hipStream_t stream, const argument& result, const argument& arg);
const argument& result,
const argument& arg);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -139,7 +139,6 @@ struct miopen_apply ...@@ -139,7 +139,6 @@ struct miopen_apply
add_generic_op("tan"); add_generic_op("tan");
add_generic_op("tanh"); add_generic_op("tanh");
add_extend_op("abs"); add_extend_op("abs");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
......
...@@ -96,6 +96,7 @@ struct tf_parser ...@@ -96,6 +96,7 @@ struct tf_parser
void parse_node(const std::string& name); void parse_node(const std::string& name);
literal parse_tensor(const tensorflow::TensorProto& t) const; literal parse_tensor(const tensorflow::TensorProto& t) const;
shape::type_t parse_type(tensorflow::DataType t) const; shape::type_t parse_type(tensorflow::DataType t) const;
std::vector<std::string> find_outputs() const;
}; };
std::vector<int64_t> get_axes_from_mask(size_t num_axes, uint32_t mask); std::vector<int64_t> get_axes_from_mask(size_t num_axes, uint32_t mask);
......
...@@ -215,6 +215,7 @@ static tf_parser::attribute_map get_attributes(const tensorflow::NodeDef& node) ...@@ -215,6 +215,7 @@ static tf_parser::attribute_map get_attributes(const tensorflow::NodeDef& node)
{ {
result[attr.first] = attr.second; result[attr.first] = attr.second;
} }
return result; return result;
} }
...@@ -254,6 +255,45 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::v ...@@ -254,6 +255,45 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::v
return literal{{shape_type, dims}, data}; return literal{{shape_type, dims}, data};
} }
// Returns true when `node` represents a real computation worth parsing.
// TF graphs contain bookkeeping ops ("NoOp", "Assert") that produce no
// values; they are rejected either by exact op type or by the op name
// containing the ignored token (TF name-scopes embed the op kind).
static bool is_valid_op(const tensorflow::NodeDef& node)
{
    static const std::vector<std::string> ignored{"NoOp", "Assert"};
    // The node name does not depend on the loop variable; look it up once.
    const auto& name = get_name(node);
    return std::none_of(ignored.begin(), ignored.end(), [&](const std::string& op) {
        return node.op() == op or contains(name, op);
    });
}
// Infers the graph outputs: every parseable node whose name is never
// consumed as an input by any other node.
std::vector<std::string> tf_parser::find_outputs() const
{
    // First pass: gather the set of all names used as inputs somewhere.
    std::unordered_set<std::string> used_as_input;
    for(auto&& entry : nodes)
    {
        auto&& graph_node = entry.second;
        used_as_input.insert(graph_node.input().begin(), graph_node.input().end());
    }
    // Second pass: a node with no consumer is an output, unless it is a
    // node we would skip while parsing.
    std::vector<std::string> result;
    for(auto&& entry : nodes)
    {
        const auto& node_name  = entry.first;
        const auto& graph_node = entry.second;
        if(not is_valid_op(graph_node))
            continue;
        // control flow related, ignore this node
        if(contains(node_name, "^"))
            continue;
        // literals are valid ops, but they are not outputs unless specified
        if(graph_node.op() == "Const")
            continue;
        if(used_as_input.count(node_name) == 0)
            result.push_back(node_name);
    }
    return result;
}
void tf_parser::parse_graph(const tensorflow::GraphDef& graph) void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
...@@ -293,22 +333,20 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph) ...@@ -293,22 +333,20 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
// the program // the program
if(output_node_names.empty()) if(output_node_names.empty())
{ {
mm->add_return({to_nchw(last_ins)}); output_node_names = find_outputs();
}
else
{
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
} }
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
} }
} }
...@@ -317,12 +355,9 @@ void tf_parser::parse_node(const std::string& name) ...@@ -317,12 +355,9 @@ void tf_parser::parse_node(const std::string& name)
if(instructions.count(name) == 0) if(instructions.count(name) == 0)
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
// assert ops ignored if(not is_valid_op(node))
if(node.op() == "Assert" or contains(name, "Assert"))
return;
// noOps ignored
if(node.op() == "NoOp" or contains(name, "NoOp"))
return; return;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
...@@ -351,7 +386,6 @@ void tf_parser::parse_node(const std::string& name) ...@@ -351,7 +386,6 @@ void tf_parser::parse_node(const std::string& name)
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
} }
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0) if(ops.count(node.op()) == 0)
{ {
...@@ -361,7 +395,6 @@ void tf_parser::parse_node(const std::string& name) ...@@ -361,7 +395,6 @@ void tf_parser::parse_node(const std::string& name)
{ {
result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args); result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
} }
assert(!result.empty()); assert(!result.empty());
// First output has no ":" delimiter // First output has no ":" delimiter
instructions[name] = result.front(); instructions[name] = result.front();
......
...@@ -1684,7 +1684,6 @@ TEST_CASE(logical_or_test) ...@@ -1684,7 +1684,6 @@ TEST_CASE(logical_or_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(logical_xor_bcast_test) TEST_CASE(logical_xor_bcast_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1840,7 +1840,7 @@ TEST_CASE(logical_and_test) ...@@ -1840,7 +1840,7 @@ TEST_CASE(logical_and_test)
TEST_CASE(not_test) TEST_CASE(not_test)
{ {
//int32 // int32
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1855,7 +1855,7 @@ TEST_CASE(not_test) ...@@ -1855,7 +1855,7 @@ TEST_CASE(not_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
//bool // bool
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1869,7 +1869,6 @@ TEST_CASE(not_test) ...@@ -1869,7 +1869,6 @@ TEST_CASE(not_test)
std::vector<char> gold = {1, 1, 0, 0}; std::vector<char> gold = {1, 1, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
} }
TEST_CASE(logical_or_test) TEST_CASE(logical_or_test)
......
...@@ -336,7 +336,6 @@ def mean_test_nhwc(g1): ...@@ -336,7 +336,6 @@ def mean_test_nhwc(g1):
g1_input = tf.compat.v1.placeholder(tf.float32, g1_input = tf.compat.v1.placeholder(tf.float32,
shape=(1, 16, 16, 3), shape=(1, 16, 16, 3),
name='0') name='0')
tf.math.reduce_mean(g1_input, axis=(1, 2), keepdims=True, name='mean1')
tf.math.reduce_mean(g1_input, tf.math.reduce_mean(g1_input,
axis=(1, 2), axis=(1, 2),
keepdims=False, keepdims=False,
......
...@@ -39,10 +39,12 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc) ...@@ -39,10 +39,12 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
// remove the last return instruction // remove the last return instruction
auto last_ins = std::prev(mm->end()); auto last_ins = std::prev(mm->end());
if(last_ins != mm->end()) if(last_ins != mm->end())
{
if(last_ins->name() == "@return") if(last_ins->name() == "@return")
{ {
mm->remove_instruction(last_ins); mm->remove_instruction(last_ins);
} }
}
return prog; return prog;
} }
...@@ -641,6 +643,7 @@ TEST_CASE(pooling_test) ...@@ -641,6 +643,7 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2}; max_pool_op.lengths = {2, 2};
mm->add_instruction(avg_pool_op, l0);
mm->add_instruction(max_pool_op, l0); mm->add_instruction(max_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true); auto prog = optimize_tf("pooling_test.pb", true);
...@@ -782,9 +785,9 @@ TEST_CASE(split_test) ...@@ -782,9 +785,9 @@ TEST_CASE(split_test)
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0);
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4}); mm->add_return({l4, l5});
auto prog = parse_tf("split_test.pb", false); auto prog = parse_tf("split_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -824,9 +827,9 @@ TEST_CASE(split_test_vector_as_input) ...@@ -824,9 +827,9 @@ TEST_CASE(split_test_vector_as_input)
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0);
auto l3 = mm->add_instruction( auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0); migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4}); mm->add_return({l4, l5});
auto prog = parse_tf("split_test_vector_as_input.pb", false); auto prog = parse_tf("split_test_vector_as_input.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment