Unverified commit 197b3a46, authored by kahmed10 and committed by GitHub
Browse files

Infer outputs in tf (#764)



* fix relu6

* add more transposes

* add multi output

* formatting

* add tests

* formatting

* fix tests

* change to_nchw for outputs

* add python api

* fix cppcheck

* remove variable

* fix lambda

* add multi_output test

* add more tests and merge

* fix help message

* debugging work

* fix valid op string

* formatting

* manual merge

* mark function as const
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Shucai Xiao <shucai@gmail.com>
parent d43cf0a3
......@@ -7,13 +7,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void unary_not(hipStream_t stream,
const argument& result,
const argument& arg
)
void unary_not(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)(
[](auto x) __device__ { return not x; });
nary(stream, result, arg)([](auto x) __device__ { return not x; });
}
} // namespace device
......
......@@ -10,9 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void unary_not(hipStream_t stream,
const argument& result,
const argument& arg);
void unary_not(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
......
......@@ -139,7 +139,6 @@ struct miopen_apply
add_generic_op("tan");
add_generic_op("tanh");
add_extend_op("abs");
add_extend_op("argmax");
add_extend_op("argmin");
......
......@@ -96,6 +96,7 @@ struct tf_parser
void parse_node(const std::string& name);
literal parse_tensor(const tensorflow::TensorProto& t) const;
shape::type_t parse_type(tensorflow::DataType t) const;
std::vector<std::string> find_outputs() const;
};
std::vector<int64_t> get_axes_from_mask(size_t num_axes, uint32_t mask);
......
......@@ -215,6 +215,7 @@ static tf_parser::attribute_map get_attributes(const tensorflow::NodeDef& node)
{
result[attr.first] = attr.second;
}
return result;
}
......@@ -254,6 +255,45 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::v
return literal{{shape_type, dims}, data};
}
// Returns true if the node represents an op that should be parsed.
// Nodes whose op type or name matches an ignored category ("NoOp",
// "Assert") carry no data-flow semantics in the graph and are skipped.
static bool is_valid_op(const tensorflow::NodeDef& node)
{
    std::vector<std::string> ignored{"NoOp", "Assert"};
    // The node's name is loop-invariant; fetch it once instead of on
    // every iteration of the ignored-op scan.
    const auto& name = get_name(node);
    for(const auto& op : ignored)
    {
        if(node.op() == op or contains(name, op))
            return false;
    }
    return true;
}
std::vector<std::string> tf_parser::find_outputs() const
{
    // First pass: record every name that feeds some other node.
    std::unordered_set<std::string> consumed;
    for(const auto& entry : nodes)
    {
        const auto& node_inputs = entry.second.input();
        consumed.insert(node_inputs.begin(), node_inputs.end());
    }
    // Second pass: a node is a graph output when it is a real op,
    // is not control-flow related (name contains "^"), is not a bare
    // literal ("Const"), and is consumed by no other node.
    std::vector<std::string> result;
    for(const auto& entry : nodes)
    {
        const auto& node_name = entry.first;
        const auto& node      = entry.second;
        const bool is_output  = is_valid_op(node) and not contains(node_name, "^") and
                               node.op() != "Const" and consumed.count(node_name) == 0;
        if(is_output)
            result.push_back(node_name);
    }
    return result;
}
void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
......@@ -293,22 +333,20 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
// the program
if(output_node_names.empty())
{
mm->add_return({to_nchw(last_ins)});
}
else
{
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
output_node_names = find_outputs();
}
std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
}
}
......@@ -317,12 +355,9 @@ void tf_parser::parse_node(const std::string& name)
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
// assert ops ignored
if(node.op() == "Assert" or contains(name, "Assert"))
return;
// noOps ignored
if(node.op() == "NoOp" or contains(name, "NoOp"))
if(not is_valid_op(node))
return;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
......@@ -351,7 +386,6 @@ void tf_parser::parse_node(const std::string& name)
args.push_back(instructions.at(input));
}
}
std::vector<instruction_ref> result;
if(ops.count(node.op()) == 0)
{
......@@ -361,7 +395,6 @@ void tf_parser::parse_node(const std::string& name)
{
result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
}
assert(!result.empty());
// First output has no ":" delimiter
instructions[name] = result.front();
......
......@@ -1684,7 +1684,6 @@ TEST_CASE(logical_or_test)
EXPECT(p == prog);
}
TEST_CASE(logical_xor_bcast_test)
{
migraphx::program p;
......
......@@ -1840,7 +1840,7 @@ TEST_CASE(logical_and_test)
TEST_CASE(not_test)
{
//int32
// int32
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -1855,7 +1855,7 @@ TEST_CASE(not_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
//bool
// bool
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -1869,7 +1869,6 @@ TEST_CASE(not_test)
std::vector<char> gold = {1, 1, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(logical_or_test)
......
......@@ -336,7 +336,6 @@ def mean_test_nhwc(g1):
g1_input = tf.compat.v1.placeholder(tf.float32,
shape=(1, 16, 16, 3),
name='0')
tf.math.reduce_mean(g1_input, axis=(1, 2), keepdims=True, name='mean1')
tf.math.reduce_mean(g1_input,
axis=(1, 2),
keepdims=False,
......
......@@ -39,10 +39,12 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
// remove the last return instruction
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
{
if(last_ins->name() == "@return")
{
mm->remove_instruction(last_ins);
}
}
return prog;
}
......@@ -641,6 +643,7 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
mm->add_instruction(avg_pool_op, l0);
mm->add_instruction(max_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true);
......@@ -782,9 +785,9 @@ TEST_CASE(split_test)
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4, l5});
auto prog = parse_tf("split_test.pb", false);
EXPECT(p == prog);
......@@ -824,9 +827,9 @@ TEST_CASE(split_test_vector_as_input)
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4});
auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2);
auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3);
mm->add_return({l4, l5});
auto prog = parse_tf("split_test_vector_as_input.pb", false);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment