"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "3497e4da0a77e02267328a57d9203d589091d528"
Unverified commit 4fdc4dfe, authored by Shucai Xiao, committed by GitHub

Where op (#630)



* add the where operator

* clang format

* add where unit tests

* add where op unit test

* clang format

* add more unit tests for the where op

* clang format

* Add support for constructing value from enum

* Formatting

* add a comment about the algorithm

* call make_op to create the convert instruction
Co-authored-by: Paul <pfultz2@yahoo.com>
parent 9f283810
@@ -171,6 +171,7 @@ struct onnx_parser
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Where", &onnx_parser::parse_where);
        // init the activation function map
        init_actv_func();
@@ -2165,7 +2166,7 @@ struct onnx_parser
        int to_type = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
    }

    std::vector<instruction_ref>
@@ -2436,11 +2437,23 @@ struct onnx_parser
        auto l = add_broadcastable_binary_op(args[0], args[1], "equal");
        if(l->get_shape().type() != shape::bool_type)
        {
            l = prog.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
        }
        return l;
    }

    instruction_ref
    parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto type = args[1]->get_shape().type();
        // selecting x where cond == 1 and y where cond == 0
        // is equivalent to computing cond * (x - y) + y
        auto cond = prog.add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
        auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
        auto cd   = add_broadcastable_binary_op(diff, cond, "mul");
        return add_broadcastable_binary_op(cd, args[2], "add");
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
...
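The algebraic trick in parse_where can be checked against NumPy's reference semantics for Where. A minimal sketch (assuming NumPy; not part of the commit):

import numpy as np

cond = np.array([True, False, True, False])
x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
y = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)

ref = np.where(cond, x, y)                   # reference: [1., 20., 3., 40.]
dec = cond.astype(np.float32) * (x - y) + y  # the parser's decomposition
assert np.array_equal(ref, dec)

One upside of this decomposition is that it reuses convert, sub, mul, and add, which the backends already implement, instead of requiring a dedicated select operator; the trade-off is that both branches are always evaluated.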
@@ -100,6 +100,11 @@ migraphx::shape to_shape(const py::buffer_info& info)
            t = as.type_enum();
            n = sizeof(as());
        }
        else if(info.format == "?" and py::format_descriptor<decltype(as())>::format() == "b")
        {
            t = migraphx::shape::bool_type;
            n = sizeof(bool);
        }
    });
    if(n == 0)
...
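The new branch covers NumPy boolean arrays, which the buffer protocol reports with the format character '?' rather than one of the numeric formats handled above. A quick illustration (assuming NumPy; not part of the commit):

import numpy as np

a = np.array([True, False])
print(memoryview(a).format)  # '?', the buffer-protocol format code for bool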
@@ -3028,3 +3028,17 @@ def variable_batch_leq_zero_test():
    node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])

    return ([node], [x, y], [z])


@onnx_test
def where_test():
    c = helper.make_tensor_value_info('c', TensorProto.BOOL, [2])
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1, 2, 2])
    z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 2, 2, 2])

    node = onnx.helper.make_node('Where',
                                 inputs=['c', 'x', 'y'],
                                 outputs=['z'])

    return ([node], [c, x, y], [z])
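The declared output shape [2, 2, 2, 2] is what multidirectional broadcasting of [2], [2, 2, 2], and [2, 1, 2, 2] yields; ONNX shape inference can confirm it in a standalone check (a sketch, not part of the commit):

import onnx
from onnx import helper, TensorProto

c = helper.make_tensor_value_info('c', TensorProto.BOOL, [2])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1, 2, 2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, None)
node = helper.make_node('Where', inputs=['c', 'x', 'y'], outputs=['z'])
graph = helper.make_graph([node], 'where_test', [c, x, y], [z])
inferred = onnx.shape_inference.infer_shapes(helper.make_model(graph))
print(inferred.graph.output[0].type.tensor_type.shape)  # dims: 2, 2, 2, 2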
@@ -2198,4 +2198,25 @@ TEST_CASE(variable_batch_leq_zero_test)
    EXPECT(p == prog);
}

TEST_CASE(where_test)
{
    migraphx::program p;
    auto lc   = p.add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
    auto lx   = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
    auto ly   = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
    auto lcc  = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, lc);
    auto lxm  = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
    auto lym  = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
    auto lxy  = p.add_instruction(migraphx::op::sub{}, lxm, lym);
    auto lccm = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lcc);
    auto lm   = p.add_instruction(migraphx::op::mul{}, lxy, lccm);
    auto lym1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
    auto r    = p.add_instruction(migraphx::op::add{}, lm, lym1);
    p.add_return({r});

    auto prog = migraphx::parse_onnx("where_test.onnx");
    EXPECT(p == prog);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -127,4 +127,46 @@ TEST_CASE(gather_elements)
    EXPECT(migraphx::verify_range(result_vector, gold));
}

TEST_CASE(where_test)
{
    migraphx::program p = migraphx::parse_onnx("where_test.onnx");
    p.compile(migraphx::cpu::target{});

    migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
    std::vector<int8_t> c_data = {1, 0};
    migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
    std::vector<float> x_data(8, 1.0f);
    migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
    std::vector<float> y_data(8, 2.0f);

    migraphx::program::parameter_map pp;
    pp["c"] = migraphx::argument(c_shape, c_data.data());
    pp["x"] = migraphx::argument(x_shape, x_data.data());
    pp["y"] = migraphx::argument(y_shape, y_data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f,
                               1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
    EXPECT(migraphx::verify_range(result_vector, gold));
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
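The gold vector follows from broadcasting cond = {1, 0} along the last axis of the [2, 2, 2, 2] result, so the output alternates x's 1.0 and y's 2.0. A NumPy sketch reproducing it (not part of the commit):

import numpy as np

cond = np.array([True, False])
x = np.ones((2, 2, 2), dtype=np.float32)
y = np.full((2, 1, 2, 2), 2.0, dtype=np.float32)
gold = np.where(cond, x, y).flatten()
print(gold)  # 16 values alternating 1.0 and 2.0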

(new binary file where_test.onnx — the serialized ONNX model generated by the where_test() case above)
@@ -151,6 +151,8 @@ def create_backend_test(testname=None, target_device=None):
    backend_test.include(r'.*test_thresholdedrelu.*')
    backend_test.include(r'.*test_transpose.*')
    backend_test.include(r'.*test_unsqueeze.*')
    backend_test.include(r'.*test_where.*')
    backend_test.include(r'.*test_ZeroPad2d*')

    # # Onnx native model tests
...