"vscode:/vscode.git/clone" did not exist on "c98b22d88a00f1daeeabb00fa4ab98e2cd918420"
Unverified Commit 4fdc4dfe authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Where op (#630)



* add the where operator

* clang format

* add where unit tests

* add where op unit test

* clang format

* add more unit tests for the where op

* clang format

* Add support for constructing value from enum

* Formatting

* add an comment about the algorithm

* call make_op to create the convert instruction
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 9f283810
......@@ -171,6 +171,7 @@ struct onnx_parser
add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Where", &onnx_parser::parse_where);
// init the activation function map
init_actv_func();
......@@ -2165,7 +2166,7 @@ struct onnx_parser
int to_type = parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
return prog.add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
}
std::vector<instruction_ref>
......@@ -2436,11 +2437,23 @@ struct onnx_parser
auto l = add_broadcastable_binary_op(args[0], args[1], "equal");
if(l->get_shape().type() != shape::bool_type)
{
l = prog.add_instruction(op::convert{shape::bool_type}, l);
l = prog.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
}
return l;
}
instruction_ref
parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto type = args[1]->get_shape().type();
// the operation of if cond == 1 select x; else select y,
// is equivalent to cond * (x - y) + y
auto cond = prog.add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
auto cd = add_broadcastable_binary_op(diff, cond, "mul");
return add_broadcastable_binary_op(cd, args[2], "add");
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -100,6 +100,11 @@ migraphx::shape to_shape(const py::buffer_info& info)
t = as.type_enum();
n = sizeof(as());
}
else if(info.format == "?" and py::format_descriptor<decltype(as())>::format() == "b")
{
t = migraphx::shape::bool_type;
n = sizeof(bool);
}
});
if(n == 0)
......
......@@ -3028,3 +3028,17 @@ def variable_batch_leq_zero_test():
node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])
return ([node], [x, y], [z])
@onnx_test
def where_test():
c = helper.make_tensor_value_info('c', TensorProto.BOOL, [2])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1, 2, 2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 2, 2, 2])
node = onnx.helper.make_node('Where',
inputs=['c', 'x', 'y'],
outputs=['z'])
return ([node], [c, x, y], [z])
......@@ -2198,4 +2198,25 @@ TEST_CASE(variable_batch_leq_zero_test)
EXPECT(p == prog);
}
TEST_CASE(where_test)
{
migraphx::program p;
auto lc = p.add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
auto lx = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto lcc = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, lc);
auto lxm = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
auto lym = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto lxy = p.add_instruction(migraphx::op::sub{}, lxm, lym);
auto lccm = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lcc);
auto lm = p.add_instruction(migraphx::op::mul{}, lxy, lccm);
auto lym1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto r = p.add_instruction(migraphx::op::add{}, lm, lym1);
p.add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -127,4 +127,46 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(where_test)
{
migraphx::program p = migraphx::parse_onnx("where_test.onnx");
p.compile(migraphx::cpu::target{});
migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
std::vector<int8_t> c_data = {1, 0};
migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
std::vector<float> x_data(8, 1.0f);
migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
std::vector<float> y_data(8, 2.0f);
migraphx::program::parameter_map pp;
pp["c"] = migraphx::argument(c_shape, c_data.data());
pp["x"] = migraphx::argument(x_shape, x_data.data());
pp["y"] = migraphx::argument(y_shape, y_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }

where_test:…

c
x
yz"Where
where_testZ
c
 
Z
x



Z
y




b
z




B
\ No newline at end of file
......@@ -151,6 +151,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_thresholdedrelu.*')
backend_test.include(r'.*test_transpose.*')
backend_test.include(r'.*test_unsqueeze.*')
backend_test.include(r'.*test_where*')
backend_test.include(r'.*test_where.*')
backend_test.include(r'.*test_ZeroPad2d*')
# # Onnx native model tests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment