Commit bbe210f8 authored by charlie

Merge branch 'dyn_broadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents 17abf67e 411abcec
@@ -42,7 +42,6 @@
 #include <migraphx/op/lrn.hpp>
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/unknown.hpp>
-#include <random>
 #include <migraphx/serialize.hpp>
...
@@ -118,48 +118,67 @@ TEST_CASE(broadcast)
     }
 }
-TEST_CASE(broadcast_2in)
-{
-    {
-        migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
-        migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
-        expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
-                     migraphx::make_op("broadcast", {{"axis", 0}}),
-                     a_input,
-                     b_input);
-        expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
-                     migraphx::make_op("broadcast", {{"axis", 1}}),
-                     a_input,
-                     b_input);
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
-    }
-    {
-        migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
-        migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
-    }
-    {
-        migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
-        migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
-    }
-    {
-        std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
-        migraphx::shape a_input{migraphx::shape::float_type, dd};
-        migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
-    }
-    {
-        migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
-        migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
-        expect_shape(
-            migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
-            migraphx::make_op("broadcast", {{"axis", 1}}),
-            a_input,
-            b_input);
-        throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
-    }
-}
+TEST_CASE(broadcast_axis_out_of_range_error)
+{
+    std::vector<std::size_t> lens{1, 1};
+    migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 4}, {"out_lens", lens}}), input);
+}
+
+TEST_CASE(broadcast_2in_static_static)
+{
+    migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
+    migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
+    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
+                 migraphx::make_op("broadcast", {{"axis", 0}}),
+                 a_input,
+                 b_input);
+    expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
+                 migraphx::make_op("broadcast", {{"axis", 1}}),
+                 a_input,
+                 b_input);
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
+}
+
+TEST_CASE(broadcast_2in_not_matching_error)
+{
+    migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
+    migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
+}
+
+TEST_CASE(broadcast_2in_dynamic_s0_error1)
+{
+    migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
+    migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
+}
+
+TEST_CASE(broadcast_2in_dynamic_s0_error2)
+{
+    std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
+    migraphx::shape a_input{migraphx::shape::float_type, dd};
+    migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
+}
+
+TEST_CASE(broadcast_2in_static_dyn)
+{
+    migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
+    migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
+    expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
+                 migraphx::make_op("broadcast", {{"axis", 1}}),
+                 a_input,
+                 b_input);
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
+}
+
+TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
+{
+    migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
+    migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
+}
 TEST_CASE(convolution_shape)
...
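Note on the new two-input broadcast tests above: the `{1, 0}` and `{0, 1}` strides asserted in `broadcast_2in_static_static` encode where the 1-D input lands along `axis` of the 2-D output, with the broadcast dimension given stride 0; the three-element initializers such as `{1, 4, 0}` in the dynamic-shape cases read as per-dimension bounds, though that reading is an assumption here rather than something stated in the diff. The snippet below is a standalone sketch of the static stride rule only, so the expected shapes can be checked by hand; it is not MIGraphX code, and `broadcast_strides` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the stride rule the static two-input tests assert: a 1-D input of
// length `len` is placed along `axis` of an output with dimensions `out_lens`;
// the matched dimension keeps unit stride, every other dimension is broadcast
// with stride 0. `broadcast_strides` is a hypothetical helper, not a MIGraphX API.
std::vector<std::size_t> broadcast_strides(std::size_t len,
                                           const std::vector<std::size_t>& out_lens,
                                           std::size_t axis)
{
    // axis 2 against a rank-2 output is rejected, as in the throws_shape cases
    assert(axis < out_lens.size());
    // a length-4 input against out_lens {2, 2} is rejected
    // (the broadcast_2in_not_matching_error case)
    assert(out_lens[axis] == len);
    std::vector<std::size_t> strides(out_lens.size(), 0);
    strides[axis] = 1;
    return strides;
}

int main()
{
    // Mirrors the two expect_shape calls in broadcast_2in_static_static
    assert((broadcast_strides(4, {4, 4}, 0) == std::vector<std::size_t>{1, 0}));
    assert((broadcast_strides(4, {4, 4}, 1) == std::vector<std::size_t>{0, 1}));
    return 0;
}
```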
@@ -29,7 +29,6 @@
 #include <migraphx/module.hpp>
 #include <sstream>
 #include <string>
-#include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
...
@@ -31,7 +31,6 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/quantization.hpp>
 #include <migraphx/ref/target.hpp>
-#include <migraphx/quantization.hpp>
 #include <migraphx/verify.hpp>
 #include <migraphx/onnx.hpp>
 #include <migraphx/make_op.hpp>
...
@@ -33,7 +33,6 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/verify.hpp>
-#include <migraphx/ref/target.hpp>
 #include <migraphx/apply_alpha_beta.hpp>
 bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
...