Unverified Commit 867539b7 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

If const cond (#739)



* if operator support with constant condition input

* clang format

* add a missing file

* clang format

* add an onnx verifcation unit test for the if operator

* clang format'

* fix review comments

* temp version to try jenkin build

* remove unnecessary changes

* unit tests refinement for more code coverage

* clang format

* try a mutex to fix possible race condition in onnxruntime tests

* tmp changes to try jenkins build

* remove unnecessary code

* fix review comments
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8640d392
......@@ -27,7 +27,7 @@ struct onnx_parser
attribute_map attributes{};
std::size_t num_outputs = 1;
std::string name = "";
module* mm = nullptr;
module* mod = nullptr;
instruction_ref make_contiguous(instruction_ref ins) const;
instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
......@@ -52,7 +52,7 @@ struct onnx_parser
};
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<std::vector<instruction_ref>(
const onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
......@@ -65,11 +65,11 @@ struct onnx_parser
onnx_parser();
operation load(const std::string& name, const node_info& info) const;
void parse_undefined(module* mm, const std::string& name);
void parse_undefined(module* mod, const std::string& name);
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
void parse_graph(const onnx::GraphProto& graph);
void parse_graph(module* mod, const onnx::GraphProto& graph);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
......@@ -66,10 +66,10 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
{
if(args.size() == 3)
{
auto bias_bcast = mm->add_instruction(
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
args[2]);
return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
......@@ -140,12 +140,12 @@ instruction_ref
onnx_parser::node_info::add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const
{
return mm->add_instruction(op, args);
return mod->add_instruction(op, args);
}
instruction_ref onnx_parser::node_info::add_literal(literal l) const
{
return mm->add_literal(std::move(l));
return mod->add_literal(std::move(l));
}
onnx_parser::onnx_parser()
......@@ -183,17 +183,18 @@ operation onnx_parser::load(const std::string& name, const node_info& info) cons
return op;
}
void onnx_parser::parse_undefined(module* mm, const std::string& name)
void onnx_parser::parse_undefined(module* mod, const std::string& name)
{
if(!contains(instructions, name))
{
auto ins = mm->add_instruction(make_op("undefined"));
auto ins = mod->add_instruction(make_op("undefined"));
instructions[name] = ins;
}
}
void onnx_parser::parse_from(std::istream& is, std::string name)
{
auto* mm = prog.get_main_module();
this->filename = std::move(name);
auto parent_path = fs::path(this->filename).parent_path();
if(not parent_path.empty())
......@@ -204,23 +205,24 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
{
if(model.has_graph())
{
this->parse_graph(model.graph());
this->parse_graph(mm, model.graph());
}
}
else
{
MIGRAPHX_THROW("Failed reading onnx file.");
MIGRAPHX_THROW("PARSE_FROM: Failed reading onnx file: " + this->filename);
}
}
void onnx_parser::parse_from(const void* data, std::size_t size)
{
auto* mm = prog.get_main_module();
onnx::ModelProto model;
if(model.ParseFromArray(data, size))
{
if(model.has_graph())
{
this->parse_graph(model.graph());
this->parse_graph(mm, model.graph());
}
}
else
......@@ -229,12 +231,11 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
}
}
void onnx_parser::parse_graph(const onnx::GraphProto& graph)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
module* mm = prog.get_main_module();
for(auto&& f : graph.initializer())
{
instructions[f.name()] = mm->add_literal(parse_tensor(f));
instructions[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
......@@ -250,7 +251,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
}
shape s = parse_type(input.type(), dims);
instructions[name] = mm->add_parameter(name, s);
instructions[name] = mod->add_parameter(name, s);
}
}
......@@ -261,7 +262,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
{
if(input.empty())
{
this->parse_undefined(mm, input);
this->parse_undefined(mod, input);
}
if(instructions.count(input) == 0)
{
......@@ -276,14 +277,14 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
if(ops.count(node.op_type()) == 0)
{
if(skip_unknown_operators)
result.push_back(mm->add_instruction(op::unknown{node.op_type()}, args));
result.push_back(mod->add_instruction(op::unknown{node.op_type()}, args));
else
MIGRAPHX_THROW("Unknown operator: " + node.op_type());
}
else
{
result = ops[node.op_type()](
*this, {get_attributes(node), output_num, node.op_type(), mm}, args);
*this, {get_attributes(node), output_num, node.op_type(), mod}, args);
}
output_num = std::min<std::size_t>(output_num, result.size());
......@@ -315,7 +316,7 @@ void onnx_parser::parse_graph(const onnx::GraphProto& graph)
[&](const auto& name) { return instructions[name]; });
// add the return instuction
mm->add_return(output_ins);
mod->add_return(output_ins);
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_if : op_parser<parse_if>
{
std::vector<op_desc> operators() const { return {{"If"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
migraphx::argument cond_arg = args.front()->eval();
// cond is not constant, need to create sub_modules
if(cond_arg.empty())
{
MIGRAPHX_THROW(
"PARSE_IF: current implementation requires condition input to be constant!");
}
if(cond_arg.get_shape().elements() != 1)
{
MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
}
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
const auto& then_graph = info.attributes.at("then_branch").g();
parser.parse_graph(mod, then_graph);
}
// else branch
else
{
const auto& else_graph = info.attributes.at("else_branch").g();
parser.parse_graph(mod, else_graph);
}
// inputs of the return instruction are that of the output of the
// if instruction
instruction_ref ret_ins = std::prev(mod->end());
auto outputs = ret_ins->inputs();
assert(ret_ins->name() == "@return");
mod->remove_instruction(ret_ins);
return outputs;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -1519,6 +1519,114 @@ def group_conv_test():
return ([node], [x, y], [z])
@onnx_test
def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond = np.array([0]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_then_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def imagescaler_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......
#include <iostream>
#include <fstream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
......@@ -1376,6 +1377,60 @@ TEST_CASE(group_conv_test)
EXPECT(p == prog);
}
TEST_CASE(if_else_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary);
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog);
}
TEST_CASE(if_then_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::make_op("add"), x, l1);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_then_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(imagescaler_test)
{
migraphx::program p;
......
......@@ -68,6 +68,25 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_else_test)
{
migraphx::program p = migraphx::parse_onnx("if_else_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
-0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(instance_norm_test)
{
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
......
......@@ -89,6 +89,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_greater.*')
backend_test.include(r'.*test_hardsigmoid.*')
backend_test.include(r'.*test_identity.*')
backend_test.include(r'.*test_if.*')
backend_test.include(r'.*test_LeakyReLU*')
backend_test.include(r'.*test_leakyrelu.*')
backend_test.include(r'.*test_less.*')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment