Unverified Commit 515fdfd2 authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Dynamic shape support for Where op. (#1528)

dyn shape support for Where operator.  Includes shape test, ref_ops test, onx_test. 
parent 27c0fe35
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -42,9 +42,17 @@ struct where ...@@ -42,9 +42,17 @@ struct where
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).same_dims(); check_shapes{inputs, *this, true}.has(3).same_dims();
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
auto s2 = inputs.at(2); auto s2 = inputs.at(2);
if(s1.dynamic() or s2.dynamic())
{
if(s1 == s2)
return s1;
MIGRAPHX_THROW("WHERE: dynamic input shapes must be the same");
}
// Compare two static shapes, returning a standard shape
if(s1 == s2 and s1.packed()) if(s1 == s2 and s1.packed())
{ {
return s1; return s1;
...@@ -63,12 +71,12 @@ struct where ...@@ -63,12 +71,12 @@ struct where
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) { visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) { args[0].visit([&](const auto condition) {
par_for(output_shape.elements(), par_for(dyn_out.computed_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; }); [&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
}); });
}); });
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where> ...@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto lens = // TODO: broadcasting for dynamic shapes is only implemented
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); // for binary ops at time of writing, not ternary ops.
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); // When it becomes available, add multibroadcasting steps in the dynamic shape case.
if(args[0]->get_shape().lens() != lens) // For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{ {
args[0] = return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
} }
else if(std::none_of(
if(args[1]->get_shape().lens() != lens) args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{ {
args[1] = // If shapes are static and any are broadcasted, insert multibroadcast ops
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]); auto lens =
} compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[2]->get_shape().lens() != lens) if(args[1]->get_shape().lens() != lens)
{ {
args[2] = args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
} }
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
return info.add_instruction(make_op("where"), args[0], args[1], args[2]); return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
else
MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
} }
}; };
......
...@@ -7288,3 +7288,32 @@ def where_test(): ...@@ -7288,3 +7288,32 @@ def where_test():
outputs=['z']) outputs=['z'])
return ([node], [c, x, y], [z]) return ([node], [c, x, y], [z])
@onnx_test()
def where_dyn_test():
c = helper.make_tensor_value_info('c', TensorProto.BOOL, [None, 2, 2])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, 2, 2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [None, 2, 2])
node = onnx.helper.make_node('Where',
inputs=['c', 'x', 'y'],
outputs=['z'])
return ([node], [c, x, y], [z])
@onnx_test()
def where_mixed_test():
# mixture of static and dynamic input shapes is not supported
c = helper.make_tensor_value_info('c', TensorProto.BOOL, [None, 2, 2])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2, 2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [None, 2, 2])
node = onnx.helper.make_node('Where',
inputs=['c', 'x', 'y'],
outputs=['z'])
return ([node], [c, x, y], [z])
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -6948,4 +6948,35 @@ TEST_CASE(where_test) ...@@ -6948,4 +6948,35 @@ TEST_CASE(where_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(where_dyn_test)
{
// TODO: broadcasting for dynamic shapes isn't implemented at time of writing.
// Update this test case to use shapes that require broadcasting, when available.
migraphx::program p;
auto* mm = p.get_main_module();
auto lc = mm->add_parameter(
"c", migraphx::shape{migraphx::shape::bool_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto lx = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto ly = mm->add_parameter(
"y", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
auto r = mm->add_instruction(migraphx::make_op("where"), lc, lx, ly);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("where_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(where_mixed_test)
{
// mixture of static and dynamic input shapes is not supported
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("where_mixed_test.onnx", options); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -3157,6 +3157,42 @@ TEST_CASE(where_broadcast_input) ...@@ -3157,6 +3157,42 @@ TEST_CASE(where_broadcast_input)
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2); expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
} }
TEST_CASE(where_dyn_input0)
{
// dynamic shapes not the same
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(where_dyn_input1)
{
// mixed static/dynamic inputs (not allowed)
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(where_dyn_input2)
{
// dynamic shapes
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(where_dyn_input3)
{
// dynamic shapes, predicate shape is different
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 4, 0}}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(roialign_test) TEST_CASE(roialign_test)
{ {
migraphx::shape sx{migraphx::shape::float_type, {3, 4, 5, 6}}; migraphx::shape sx{migraphx::shape::float_type, {3, 4, 5, 6}};
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -8110,6 +8110,37 @@ TEST_CASE(where_test) ...@@ -8110,6 +8110,37 @@ TEST_CASE(where_test)
EXPECT(migraphx::verify_range(result_vec, gold)); EXPECT(migraphx::verify_range(result_vec, gold));
} }
TEST_CASE(where_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}};
auto lb = mm->add_parameter("predicate", sb);
auto lx = mm->add_parameter("X", sx);
auto ly = mm->add_parameter("Y", sx);
mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
p.compile(migraphx::ref::target{});
std::vector<char> b{1, 1, 1, 0, 0, 0, 1, 0, 1};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3, 3}};
migraphx::shape input_fixed_shape1{migraphx::shape::uint8_type, {3, 3}};
params["X"] = migraphx::argument(input_fixed_shape0, x.data());
params["Y"] = migraphx::argument(input_fixed_shape0, y.data());
params["predicate"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(where_broadcasted_inputs_test) TEST_CASE(where_broadcasted_inputs_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment