Unverified commit 2590502c, authored by Brian Pickrell and committed by GitHub

Dynamic shape support for concat op. (#1526)

* Add dynamic shape support to the concat operator; includes new op_shape_test and ref_ops_test cases.
parent 515fdfd2
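
In short, the shape rule this commit adds: when every input is dynamic, each non-axis dynamic_dimension must match across inputs and is carried over unchanged, while the {min, max} bounds are summed along the concat axis (the axis optimum is reset to 0). Below is a minimal standalone sketch of that rule, assuming only the {min, max, opt} aggregate layout of migraphx::shape::dynamic_dimension used in the hunks that follow; it is an illustration, not the committed implementation.

#include <cstddef>
#include <vector>

struct dynamic_dimension // stand-in for migraphx::shape::dynamic_dimension
{
    std::size_t min;
    std::size_t max;
    std::size_t opt;
};

std::vector<dynamic_dimension>
concat_dyn_dims(const std::vector<std::vector<dynamic_dimension>>& inputs, std::size_t axis)
{
    // Non-axis dims must already match across inputs; carry them from the first input.
    auto out  = inputs.front();
    out[axis] = {0, 0, 0};
    for(const auto& dims : inputs)
    {
        out[axis].min += dims[axis].min; // lower bounds add along the concat axis
        out[axis].max += dims[axis].max; // upper bounds add along the concat axis
    }
    return out;
}

Mixing static and dynamic inputs is rejected outright, as the third branch of normalize_compute_shape in the concat hunk shows.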
@@ -87,7 +87,7 @@ struct check_shapes
     }

     /*!
-     * Check if the number of shape objects is equal to atleast one of the
+     * Require the number of shape objects to equal one of the
      * given sizes.
      * \param ns template parameter pack of sizes to check against
      */
@@ -100,6 +100,23 @@ struct check_shapes
         return *this;
     }

+    /*!
+     * Require the number of shape objects to equal at least a given amount. Use this
+     * method for ops that can take any number (variadic) of inputs.
+     * \param n min. number of shapes
+     */
+    const check_shapes& has_at_least(std::size_t n) const
+    {
+        if(this->size() < n)
+            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
+                           to_string(n) + " but given " + std::to_string(size()));
+        return *this;
+    }
+
+    /*!
+     * Require all shapes to have the same number of elements.
+     * \param n number of elements
+     */
     const check_shapes& nelements(std::size_t n) const
     {
         if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
......
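has_at_least composes with the existing fluent chain. A hypothetical call site (not taken from this commit), assuming the usual check_shapes construction inside an op's normalize_compute_shape:

// Hypothetical: require at least two inputs; other checks still chain as before.
check_shapes{inputs, *this, true}.has_at_least(2).same_ndims().same_type();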
@@ -26,6 +26,7 @@
 #include <array>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>
@@ -73,49 +74,87 @@ struct concat
         }
         return offsets;
     }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        if(inputs.empty())
-        {
-            MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0");
-        }
-        const auto& first_shape_lens = inputs.front().lens();
-        const auto& type             = inputs.front().type();
-        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
-        {
-            if(l != axis)
-            {
-                if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
-                       return s.lens()[l] == first_shape_lens[l];
-                   }))
-                {
-                    MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match");
-                }
-            }
-        }
-        std::size_t new_dim_axis = 0;
-        for(const auto& input : inputs)
-        {
-            const auto& lens = input.lens();
-            new_dim_axis += lens[axis];
-        }
-        std::vector<std::size_t> new_lens;
-        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
-        new_lens[axis] = new_dim_axis;
-        return shape::from_permutation(type, new_lens, find_permutation(inputs));
+        // inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
+        // be at least 1.
+        check_shapes{inputs, *this, true}.same_ndims().same_type();
+        if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
+        {
+            // Static input shapes
+            const auto& first_shape_lens = inputs.front().lens();
+            const auto& type             = inputs.front().type();
+            for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
+            {
+                if(ll != axis)
+                {
+                    if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
+                           return s.lens()[ll] == first_shape_lens[ll];
+                       }))
+                    {
+                        MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
+                                       std::to_string(ll));
+                    }
+                }
+            }
+            std::size_t new_dim_axis = 0;
+            for(const auto& input : inputs)
+            {
+                const auto& lens = input.lens();
+                new_dim_axis += lens[axis];
+            }
+            std::vector<std::size_t> new_lens = first_shape_lens;
+            new_lens[axis]                    = new_dim_axis;
+            return shape::from_permutation(type, new_lens, find_permutation(inputs));
+        }
+        else if(std::all_of(
+                    inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
+        {
+            // Dynamic input shapes
+            for(std::size_t index = 0; index < inputs[0].ndim(); index++)
+            {
+                if(index != axis)
+                {
+                    if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
+                           return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
+                       }))
+                        MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
+                                       std::to_string(index));
+                }
+            }
+            std::size_t new_min = 0;
+            std::size_t new_max = 0;
+            for(const auto& input : inputs)
+            {
+                auto ddim = input.dyn_dims()[axis];
+                new_min += ddim.min;
+                new_max += ddim.max;
+            }
+            auto new_dims  = inputs[0].dyn_dims();
+            new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0};
+            return {inputs[0].type(), new_dims};
+        }
+        else
+        {
+            MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
-        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
+        argument result{dyn_out.computed_shape};
+        std::vector<std::size_t> coffsets = compute_offsets(dyn_out.computed_shape, args);
         for(std::size_t l = 0; l < args.size(); l++)
         {
             auto argl = args[l];
             visit_all(result, argl)([&](auto output, auto input) {
-                auto slice_shape =
-                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
-                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
+                auto slice_shape = shape{dyn_out.computed_shape.type(),
+                                         input.get_shape().lens(),
+                                         dyn_out.computed_shape.strides()};
+                auto slice       = make_view(slice_shape, output.data() + coffsets[l]);
                 std::copy(input.begin(), input.end(), slice.begin());
             });
         }
......
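To make the dynamic branch concrete, here is a hedged sketch using the shapes from the ref test later in this commit; compute_shape on the type-erased operation routes through normalize_compute_shape, and the function name here is illustrative only.

#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

void concat_dyn_shape_sketch()
{
    migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, 2}, {2, 3, 2}}};
    migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, 4}, {2, 3, 2}}};
    migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, 3}, {2, 3, 2}}};
    auto op  = migraphx::make_op("concat", {{"axis", 0}});
    auto out = op.compute_shape({s0, s1, s2});
    // Axis-0 bounds sum: min 2+3+1 = 6, max 4+4+5 = 13, opt reset to 0,
    // so out is {int32_type, {{6, 13, 0}, {2, 3, 2}}}.
}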
@@ -141,6 +141,8 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     -> decltype(x.normalize_compute_shape(inputs))
 {
     dependent_type<operation, T> y = x;
+    if(inputs.empty())
+        MIGRAPHX_THROW("At least one input is required for " + x.name());
     normalize_attributes(y, inputs[0].max_lens());
     return any_cast<T>(y).normalize_compute_shape(inputs);
 }
......
@@ -759,6 +759,22 @@ def concat_test():
     return ([node], [x, y], [z])


+@onnx_test()
+def concat_dyn_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, None, 3])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, None, 3])
+    z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None, None, 3])
+
+    node = onnx.helper.make_node(
+        'Concat',
+        inputs=['0', '1'],
+        axis=0,
+        outputs=['2'],
+    )
+
+    return ([node], [x, y], [z])
+
+
 @onnx_test()
 def constant_test():
     x = np.array([0, 1, 2])
......
@@ -840,6 +840,25 @@ TEST_CASE(concat_test)
     EXPECT(p == prog);
 }

+TEST_CASE(concat_dyn_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter(
+        "0", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
+    auto l1 = mm->add_parameter(
+        "1", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
+    auto ret = mm->add_instruction(migraphx::make_op("concat"), l0, l1);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 4, 0};
+    auto prog                     = parse_onnx("concat_dyn_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(constant_test)
 {
     migraphx::program p;
......
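Reading this test against concat_dyn_test in gen_onnx above: each None dimension takes options.default_dyn_dim_value while the literal 3 stays fixed, which is why the expected parameters are built as they are. The mapping below is inferred from this test pair, not from separate API documentation.

// [None, None, 3] with default_dyn_dim_value = {1, 4, 0} parses to:
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}};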
@@ -3215,4 +3215,52 @@ TEST_CASE(roialign_test)
     throws_shape(migraphx::make_op("roialign"), sx, srois2, sbi);
 }

+TEST_CASE(test_concat)
+{
+    migraphx::shape sx{migraphx::shape::float_type, {3, 4, 5, 6}};
+    migraphx::shape sy{migraphx::shape::float_type, {3, 4, 1, 6}};
+    migraphx::shape sout{migraphx::shape::float_type, {3, 4, 6, 6}};
+    expect_shape(sout, migraphx::make_op("concat", {{"axis", 2}}), sx, sy);
+
+    // axis out of range
+    throws_shape(migraphx::make_op("concat", {{"axis", 11}}), sx, sy);
+
+    // 1 input; no-op
+    expect_shape(sx, migraphx::make_op("concat", {{"axis", 2}}), sx);
+
+    // rank doesn't match
+    migraphx::shape sbi1{migraphx::shape::int64_type, {2, 3}};
+    throws_shape(migraphx::make_op("concat", {{"axis", 0}}), sx, sbi1);
+
+    // non-matching dimension 2
+    throws_shape(migraphx::make_op("concat", {{"axis", 1}}), sx, sy);
+
+    // no input shapes (at least one is required)
+    throws_shape(migraphx::make_op("concat", {{"axis", 0}}));
+}
+
+TEST_CASE(test_dyn_concat)
+{
+    migraphx::shape sx{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 5, 5}, {6, 6}}};
+    migraphx::shape sy{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 4, 4}, {6, 6}}};
+    migraphx::shape sout{migraphx::shape::float_type, {{1, 3, 3}, {4, 4, 0}, {2, 9, 0}, {6, 6}}};
+    expect_shape(sout, migraphx::make_op("concat", {{"axis", 2}}), sx, sy);
+
+    // axis out of range
+    throws_shape(migraphx::make_op("concat", {{"axis", 4}}), sx, sy);
+
+    // rank doesn't match
+    migraphx::shape srank{migraphx::shape::int64_type, {{1, 3, 3}, {4, 4}}};
+    throws_shape(migraphx::make_op("concat", {{"axis", 0}}), sx, srank);
+
+    // non-matching dimension 2
+    throws_shape(migraphx::make_op("concat", {{"axis", 1}}), sx, sy);
+
+    // static and dynamic shapes together
+    migraphx::shape sstat{migraphx::shape::float_type, {3, 4, 1, 6}};
+    throws_shape(migraphx::make_op("concat", {{"axis", 2}}), sx, sstat);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
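As a worked check of the expected sout in test_dyn_concat: concatenating on axis 2 sums the bounds, (1 + 1, 5 + 4) = (2, 9), and resets the optimum to 0, while {1, 3, 3}, {4, 4}, and {6, 6} carry over from the (matching) non-axis dimensions of the inputs, giving exactly the {{1, 3, 3}, {4, 4, 0}, {2, 9, 0}, {6, 6}} literal above.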
@@ -953,6 +953,41 @@ TEST_CASE(concat_test)
     }
 }

+TEST_CASE(concat_dyn_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    int axis = 0;
+    migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, 2}, {2, 3, 2}}};
+    migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, 4}, {2, 3, 2}}};
+    migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, 3}, {2, 3, 2}}};
+    auto input0 = mm->add_parameter("X", s0);
+    auto input1 = mm->add_parameter("Y", s1);
+    auto input2 = mm->add_parameter("Z", s2);
+    mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), input0, input1, input2);
+    p.compile(migraphx::ref::target{});
+
+    migraphx::shape static_shape0{migraphx::shape::int32_type, {2, 2}};
+    migraphx::shape static_shape1{migraphx::shape::int32_type, {3, 2}};
+    migraphx::shape static_shape2{migraphx::shape::int32_type, {1, 2}};
+    std::vector<int> data0 = {0, 1, 2, 3};
+    std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
+    std::vector<int> data2 = {10, 11};
+    migraphx::parameter_map params;
+    params["X"] = migraphx::argument(static_shape0, data0.data());
+    params["Y"] = migraphx::argument(static_shape1, data1.data());
+    params["Z"] = migraphx::argument(static_shape2, data2.data());
+
+    auto result = p.eval(params).back();
+    std::vector<int> results_vector(12);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    EXPECT(migraphx::verify_range(results_vector, gold));
+    EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
+}
+
 TEST_CASE(contiguous_test)
 {
     migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
......
@@ -140,6 +140,8 @@ template <class T>
 auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     -> decltype(x.normalize_compute_shape(inputs))
 {
+    if(inputs.empty())
+        MIGRAPHX_THROW("At least one input is required for " + x.name());
     dependent_type<operation, T> y = x;
     normalize_attributes(y, inputs[0].max_lens());
     return any_cast<T>(y).normalize_compute_shape(inputs);
......
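One reason the guard sits ahead of normalize_attributes: inputs[0].max_lens() indexes the vector, so an empty argument list has to be rejected first. A hedged sketch of what now happens at the call site (hypothetical snippet, not part of the commit; assumes the headers from the sketch above):

// Before this commit, an empty input vector reached inputs[0] (undefined
// behavior); now compute_shape throws a clear error instead.
auto op = migraphx::make_op("concat", {{"axis", 0}});
op.compute_shape(std::vector<migraphx::shape>{}); // throws "At least one input is required for concat"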