"test/vscode:/vscode.git/clone" did not exist on "8ea8473d716d4d9958998606520599edcb3e159d"
Commit bbe210f8 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_broadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents 17abf67e 411abcec
......@@ -42,7 +42,6 @@
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp>
#include <random>
#include <migraphx/serialize.hpp>
......
......@@ -118,48 +118,67 @@ TEST_CASE(broadcast)
}
}
TEST_CASE(broadcast_2in)
TEST_CASE(broadcast_axis_out_of_range_error)
{
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
migraphx::make_op("broadcast", {{"axis", 0}}),
a_input,
b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
}
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
}
{
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
migraphx::shape a_input{migraphx::shape::float_type, dd};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 4}, {"out_lens", lens}}), input);
}
TEST_CASE(broadcast_2in_static_static)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
migraphx::make_op("broadcast", {{"axis", 0}}),
a_input,
b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_not_matching_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error1)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error2)
{
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
migraphx::shape a_input{migraphx::shape::float_type, dd};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_static_dyn)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(convolution_shape)
......
......@@ -29,7 +29,6 @@
#include <migraphx/module.hpp>
#include <sstream>
#include <string>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
......
......@@ -31,7 +31,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
......
......@@ -33,7 +33,6 @@
#include <migraphx/matcher.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment