Unverified Commit 867539b7 authored by Shucai Xiao, committed by GitHub

If const cond (#739)



* if operator support with constant condition input (illustrated by the sketch below this list)

* clang format

* add a missing file

* clang format

* add an onnx verification unit test for the if operator

* clang format

* fix review comments

* temp version to try Jenkins build

* remove unnecessary changes

* unit tests refinement for more code coverage

* clang format

* try a mutex to fix possible race condition in onnxruntime tests

* tmp changes to try Jenkins build

* remove unnecessary code

* fix review comments
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
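
For context, a minimal sketch of the pattern this change supports: an If node whose cond input is a constant initializer, so the parser can evaluate the condition at parse time and inline only the taken branch. This is a sketch only; the function and tensor names below are illustrative and not part of the diff.

import numpy as np
from onnx import TensorProto, helper

def const_cond_if_sketch():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3])
    res = helper.make_tensor_value_info('res', TensorProto.FLOAT, [2, 3])
    then_out = helper.make_tensor_value_info('then_out', TensorProto.FLOAT, [2, 3])
    else_out = helper.make_tensor_value_info('else_out', TensorProto.FLOAT, [2, 3])

    # constant tensor captured by both branches from the outer graph
    ones = helper.make_tensor('ones', TensorProto.FLOAT, [2, 3],
                              np.ones((2, 3), dtype=np.float32).flatten())
    then_body = helper.make_graph([helper.make_node('Add', ['x', 'ones'], ['then_out'])],
                                  'then_body', [], [then_out])
    else_body = helper.make_graph([helper.make_node('Mul', ['x', 'ones'], ['else_out'])],
                                  'else_body', [], [else_out])

    # cond is an initializer, i.e. known at parse time, so the If node can be
    # folded into the instructions of the chosen branch (here the then branch)
    cond = helper.make_tensor('cond', TensorProto.BOOL, [1], np.array([True]))
    node = helper.make_node('If', ['cond'], ['res'],
                            then_branch=then_body, else_branch=else_body)
    graph = helper.make_graph([node], 'const_cond_if', [x], [res],
                              initializer=[cond, ones])
    return helper.make_model(graph)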
parent 8640d392
@@ -27,7 +27,7 @@ struct onnx_parser
         attribute_map attributes{};
         std::size_t num_outputs = 1;
         std::string name = "";
-        module* mm = nullptr;
+        module* mod = nullptr;
         instruction_ref make_contiguous(instruction_ref ins) const;
         instruction_ref add_bias(const std::vector<instruction_ref>& args,
                                  instruction_ref curr_ins,
@@ -52,7 +52,7 @@ struct onnx_parser
     };
     using node_map = std::unordered_map<std::string, onnx::NodeProto>;
     using op_func  = std::function<std::vector<instruction_ref>(
-        const onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
+        onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
@@ -65,11 +65,11 @@ struct onnx_parser
     onnx_parser();
     operation load(const std::string& name, const node_info& info) const;
-    void parse_undefined(module* mm, const std::string& name);
+    void parse_undefined(module* mod, const std::string& name);
     void parse_from(std::istream& is, std::string name = "");
     void parse_from(const void* data, std::size_t size);
-    void parse_graph(const onnx::GraphProto& graph);
+    void parse_graph(module* mod, const onnx::GraphProto& graph);
     literal parse_value(const onnx::AttributeProto& attr) const;
     literal parse_tensor(const onnx::TensorProto& t) const;
     shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
...
@@ -66,10 +66,10 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
 {
     if(args.size() == 3)
     {
-        auto bias_bcast = mm->add_instruction(
+        auto bias_bcast = mod->add_instruction(
             make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
             args[2]);
-        return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
+        return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
 }
@@ -140,12 +140,12 @@ instruction_ref
 onnx_parser::node_info::add_instruction(const operation& op,
                                         const std::vector<instruction_ref>& args) const
 {
-    return mm->add_instruction(op, args);
+    return mod->add_instruction(op, args);
 }

 instruction_ref onnx_parser::node_info::add_literal(literal l) const
 {
-    return mm->add_literal(std::move(l));
+    return mod->add_literal(std::move(l));
 }

 onnx_parser::onnx_parser()
@@ -183,17 +183,18 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
     return op;
 }

-void onnx_parser::parse_undefined(module* mm, const std::string& name)
+void onnx_parser::parse_undefined(module* mod, const std::string& name)
 {
     if(!contains(instructions, name))
     {
-        auto ins = mm->add_instruction(make_op("undefined"));
+        auto ins = mod->add_instruction(make_op("undefined"));
         instructions[name] = ins;
     }
 }

 void onnx_parser::parse_from(std::istream& is, std::string name)
 {
+    auto* mm = prog.get_main_module();
     this->filename = std::move(name);
     auto parent_path = fs::path(this->filename).parent_path();
     if(not parent_path.empty())
@@ -204,23 +205,24 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
     {
         if(model.has_graph())
         {
-            this->parse_graph(model.graph());
+            this->parse_graph(mm, model.graph());
         }
     }
     else
     {
-        MIGRAPHX_THROW("Failed reading onnx file.");
+        MIGRAPHX_THROW("PARSE_FROM: Failed reading onnx file: " + this->filename);
     }
 }

 void onnx_parser::parse_from(const void* data, std::size_t size)
 {
+    auto* mm = prog.get_main_module();
     onnx::ModelProto model;
     if(model.ParseFromArray(data, size))
     {
         if(model.has_graph())
         {
-            this->parse_graph(model.graph());
+            this->parse_graph(mm, model.graph());
         }
     }
     else
@@ -229,12 +231,11 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
     }
 }

-void onnx_parser::parse_graph(const onnx::GraphProto& graph)
+void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
-    module* mm = prog.get_main_module();
     for(auto&& f : graph.initializer())
     {
-        instructions[f.name()] = mm->add_literal(parse_tensor(f));
+        instructions[f.name()] = mod->add_literal(parse_tensor(f));
     }

     for(auto&& input : graph.input())
@@ -250,7 +251,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
             }

             shape s = parse_type(input.type(), dims);
-            instructions[name] = mm->add_parameter(name, s);
+            instructions[name] = mod->add_parameter(name, s);
         }
     }
@@ -261,7 +262,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
         {
             if(input.empty())
             {
-                this->parse_undefined(mm, input);
+                this->parse_undefined(mod, input);
             }
             if(instructions.count(input) == 0)
             {
@@ -276,14 +277,14 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
         if(ops.count(node.op_type()) == 0)
         {
             if(skip_unknown_operators)
-                result.push_back(mm->add_instruction(op::unknown{node.op_type()}, args));
+                result.push_back(mod->add_instruction(op::unknown{node.op_type()}, args));
             else
                 MIGRAPHX_THROW("Unknown operator: " + node.op_type());
         }
         else
         {
             result = ops[node.op_type()](
-                *this, {get_attributes(node), output_num, node.op_type(), mm}, args);
+                *this, {get_attributes(node), output_num, node.op_type(), mod}, args);
         }

         output_num = std::min<std::size_t>(output_num, result.size());
@@ -315,7 +316,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
                    [&](const auto& name) { return instructions[name]; });

     // add the return instuction
-    mm->add_return(output_ins);
+    mod->add_return(output_ins);
 }

 literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_if : op_parser<parse_if>
{
    std::vector<op_desc> operators() const { return {{"If"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        migraphx::argument cond_arg = args.front()->eval();
        // a non-constant condition would require sub-modules, which are not supported yet
        if(cond_arg.empty())
        {
            MIGRAPHX_THROW(
                "PARSE_IF: current implementation requires condition input to be constant!");
        }

        if(cond_arg.get_shape().elements() != 1)
        {
            MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
        }

        auto* mod = info.mod;
        // the condition is known at parse time, so only the taken branch is parsed
        // into the current module
        if(cond_arg.at<bool>())
        {
            const auto& then_graph = info.attributes.at("then_branch").g();
            parser.parse_graph(mod, then_graph);
        }
        else
        {
            const auto& else_graph = info.attributes.at("else_branch").g();
            parser.parse_graph(mod, else_graph);
        }

        // the inputs of the branch's @return instruction become the outputs of the If node
        instruction_ref ret_ins = std::prev(mod->end());
        auto outputs = ret_ins->inputs();
        assert(ret_ins->name() == "@return");
        mod->remove_instruction(ret_ins);

        return outputs;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -1519,6 +1519,114 @@ def group_conv_test():
     return ([node], [x, y], [z])
@onnx_test
def if_else_test():
    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])

    then_out = onnx.helper.make_tensor_value_info('then_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3])
    else_out = onnx.helper.make_tensor_value_info('else_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3])

    xt = np.ones((2, 3)).astype(np.float32)
    xt_tensor = helper.make_tensor(name='xt',
                                   data_type=TensorProto.FLOAT,
                                   dims=xt.shape,
                                   vals=xt.flatten().astype(np.float32))

    yt = np.random.randn(2, 3).astype(np.float32)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.FLOAT,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.float32))

    then_add_node = onnx.helper.make_node('Add',
                                          inputs=['x', 'xt'],
                                          outputs=['then_out'])
    else_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out'])

    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
                                       [then_out])
    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
                                       [else_out])

    # constant false condition, so the else branch is the one taken
    cond = np.array([0]).astype(bool)
    cond_tensor = helper.make_tensor(name="cond",
                                     data_type=TensorProto.BOOL,
                                     dims=cond.shape,
                                     vals=cond.astype(bool))

    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])

    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res'],
                                 then_branch=then_body,
                                 else_branch=else_body)

    return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_then_test():
    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])

    then_out = onnx.helper.make_tensor_value_info('then_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3])
    else_out = onnx.helper.make_tensor_value_info('else_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3])

    xt = np.ones((2, 3)).astype(np.float32)
    xt_tensor = helper.make_tensor(name='xt',
                                   data_type=TensorProto.FLOAT,
                                   dims=xt.shape,
                                   vals=xt.flatten().astype(np.float32))

    yt = np.random.randn(2, 3).astype(np.float32)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.FLOAT,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.float32))

    then_add_node = onnx.helper.make_node('Add',
                                          inputs=['x', 'xt'],
                                          outputs=['then_out'])
    else_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out'])

    then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
                                       [then_out])
    else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
                                       [else_out])

    # constant true condition, so the then branch is the one taken
    cond = np.array([1]).astype(bool)
    cond_tensor = helper.make_tensor(name="cond",
                                     data_type=TensorProto.BOOL,
                                     dims=cond.shape,
                                     vals=cond.astype(bool))

    res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])

    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res'],
                                 then_branch=then_body,
                                 else_branch=else_body)

    return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def imagescaler_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
...
 #include <iostream>
+#include <fstream>
 #include <vector>
 #include <migraphx/literal.hpp>
 #include <migraphx/operators.hpp>
@@ -1376,6 +1377,60 @@ TEST_CASE(group_conv_test)
     EXPECT(p == prog);
 }
TEST_CASE(if_else_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sc{migraphx::shape::bool_type, {1}};
    mm->add_literal(migraphx::literal(sc, {0}));
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    std::vector<float> ones(s.elements(), 1.0f);
    mm->add_literal(s, ones);
    std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
    auto l2 = mm->add_literal(s, rand);
    mm->add_parameter("x", s);
    auto y = mm->add_parameter("y", s);
    // the condition literal is false, so only the else branch (Mul) is expected
    auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
    mm->add_return({r});

    // read the model into memory and use the buffer overload of the parser
    std::ifstream ifs("if_else_test.onnx", std::ios::binary);
    ifs.seekg(0, std::ios::end);
    auto length = ifs.tellg();
    ifs.seekg(0, std::ios::beg);
    std::vector<char> onnx_buffer(length);
    ifs.read(onnx_buffer.data(), length);
    ifs.close();

    auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});

    EXPECT(p == prog);
}
TEST_CASE(if_then_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sc{migraphx::shape::bool_type, {1}};
    mm->add_literal(migraphx::literal(sc, {1}));
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    std::vector<float> ones(s.elements(), 1.0f);
    auto l1 = mm->add_literal(s, ones);
    std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
    mm->add_literal(s, rand);
    auto x = mm->add_parameter("x", s);
    mm->add_parameter("y", s);
    // the condition literal is true, so only the then branch (Add) is expected
    auto r = mm->add_instruction(migraphx::make_op("add"), x, l1);
    mm->add_return({r});

    auto prog = migraphx::parse_onnx("if_then_test.onnx");

    EXPECT(p == prog);
}
TEST_CASE(imagescaler_test)
{
    migraphx::program p;
...
@@ -68,6 +68,25 @@ TEST_CASE(gather_elements)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
TEST_CASE(if_else_test)
{
    migraphx::program p = migraphx::parse_onnx("if_else_test.onnx");
    p.compile(migraphx::ref::target{});

    migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
    std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
    migraphx::parameter_map pp;
    pp["y"] = migraphx::argument(s_data, data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // the constant condition selects the else branch, so gold is y * yt
    std::vector<float> gold = {
        -0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557};
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(instance_norm_test)
{
    migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
...
@@ -89,6 +89,7 @@ def create_backend_test(testname=None, target_device=None):
     backend_test.include(r'.*test_greater.*')
     backend_test.include(r'.*test_hardsigmoid.*')
     backend_test.include(r'.*test_identity.*')
+    backend_test.include(r'.*test_if.*')
     backend_test.include(r'.*test_LeakyReLU*')
     backend_test.include(r'.*test_leakyrelu.*')
     backend_test.include(r'.*test_less.*')
...