Commit e92fef6e authored by charlie's avatar charlie
Browse files

revert broadcast and multibroadcast update

update multibroadcast shape tests
parent 553a8d02
......@@ -168,7 +168,7 @@ instruction_ref insert_common_op(module& m,
if(inputs.size() != 2)
{
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs");
"inputs, only handle two inputs if any are dynamic shape");
}
auto c_type = compute_common_types(to_shapes(inputs));
......@@ -176,22 +176,19 @@ instruction_ref insert_common_op(module& m,
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape
// TODO: compute_broadcasted_dyn_dims() is going to be called again in the multibroadcast
// compute_shape should figure out a way to get around recomputing that. Attribute in
// multibroadcast?
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[0] =
m.insert_instruction(ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}),
inputs[0] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
}
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] =
m.insert_instruction(ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}),
inputs[1] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1],
inputs[0]);
}
......
......@@ -55,11 +55,9 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2);
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.at(0);
auto t = s0.type();
if(inputs.size() == 1)
{
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
......@@ -84,10 +82,6 @@ struct broadcast
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
else
{
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
......
......@@ -43,6 +43,8 @@ namespace op {
struct multibroadcast
{
std::vector<std::size_t> output_lens;
// optional attribute
std::vector<shape::dynamic_dimension> output_dyn_dims;
template <class Self, class F>
......@@ -62,7 +64,7 @@ struct multibroadcast
if(s0.max_lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
}
auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
......@@ -81,7 +83,7 @@ struct multibroadcast
{
if(s0.lens().size() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
}
auto offset = output_lens.size() - s0.lens().size();
......
......@@ -30,6 +30,7 @@
#include <numeric>
#include <memory>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
......@@ -89,7 +90,10 @@ struct shape
std::size_t opt = 0;
template <class Self, class F>
static auto reflect(Self& self, F f);
static auto reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool is_fixed() const;
bool has_optimal() const;
......
......@@ -503,15 +503,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return (x.min == y.min and x.max == y.max and x.opt == y.opt);
// don't check opt if both are fixed
bool check_opt = (x.is_fixed() and y.is_fixed()) ? false : true;
return (x.min == y.min and x.max == y.max and (check_opt ? x.opt == y.opt : true));
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......
......@@ -1116,25 +1116,21 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in)
{
// static-dyn
{
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
......@@ -1144,70 +1140,68 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
// weirdness begins
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
// opt handling
// dyn-dyn
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 2}, {6, 6, 6}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 6}}},
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 3}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 1}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
// dyn-dyn not handled
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
// fixed-fixed
// static-static
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
......@@ -1215,6 +1209,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
......@@ -1223,6 +1221,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {8, 1}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
......@@ -1231,6 +1233,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {4, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
......@@ -1239,6 +1245,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
......@@ -1247,11 +1257,16 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment