Commit e92fef6e authored by charlie's avatar charlie
Browse files

revert broadcast and multibroadcast update

update multibroadcast shape tests
parent 553a8d02
...@@ -168,7 +168,7 @@ instruction_ref insert_common_op(module& m, ...@@ -168,7 +168,7 @@ instruction_ref insert_common_op(module& m,
if(inputs.size() != 2) if(inputs.size() != 2)
{ {
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) + MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs"); "inputs, only handle two inputs if any are dynamic shape");
} }
auto c_type = compute_common_types(to_shapes(inputs)); auto c_type = compute_common_types(to_shapes(inputs));
...@@ -176,24 +176,21 @@ instruction_ref insert_common_op(module& m, ...@@ -176,24 +176,21 @@ instruction_ref insert_common_op(module& m,
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape()); compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape // following should work for a static or dynamic shape
// TODO: compute_broadcasted_dyn_dims() is going to be called again in the multibroadcast
// compute_shape should figure out a way to get around recomputing that. Attribute in
// multibroadcast?
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims) if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{ {
inputs[0] = inputs[0] = m.insert_instruction(
m.insert_instruction(ins, ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}), make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0], inputs[0],
inputs[1]); inputs[1]);
} }
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims) if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{ {
inputs[1] = inputs[1] = m.insert_instruction(
m.insert_instruction(ins, ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}), make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1], inputs[1],
inputs[0]); inputs[0]);
} }
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
......
...@@ -55,38 +55,32 @@ struct broadcast ...@@ -55,38 +55,32 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto t = s0.type(); auto t = s0.type();
if(inputs.size() == 1) std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{ {
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); MIGRAPHX_THROW("BROADCAST : axis is out of range");
// the broadcast op is deprecated now, so not handling the negative }
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)}; if(broadcast_lens.size() - axis < s0.lens().size())
if(output.elements() < s0.elements()) {
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size"); MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
return output;
} }
else
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
} }
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -43,6 +43,8 @@ namespace op { ...@@ -43,6 +43,8 @@ namespace op {
struct multibroadcast struct multibroadcast
{ {
std::vector<std::size_t> output_lens; std::vector<std::size_t> output_lens;
// optional attribute
std::vector<shape::dynamic_dimension> output_dyn_dims; std::vector<shape::dynamic_dimension> output_dyn_dims;
template <class Self, class F> template <class Self, class F>
...@@ -62,7 +64,7 @@ struct multibroadcast ...@@ -62,7 +64,7 @@ struct multibroadcast
if(s0.max_lens().empty()) if(s0.max_lens().empty())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0"); MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
} }
auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) { auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
...@@ -81,7 +83,7 @@ struct multibroadcast ...@@ -81,7 +83,7 @@ struct multibroadcast
{ {
if(s0.lens().size() > output_lens.size()) if(s0.lens().size() > output_lens.size())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size"); MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
} }
auto offset = output_lens.size() - s0.lens().size(); auto offset = output_lens.size() - s0.lens().size();
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -89,7 +90,10 @@ struct shape ...@@ -89,7 +90,10 @@ struct shape
std::size_t opt = 0; std::size_t opt = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f); static auto reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool is_fixed() const; bool is_fixed() const;
bool has_optimal() const; bool has_optimal() const;
......
...@@ -503,15 +503,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; ...@@ -503,15 +503,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; } bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
template <class Self, class F>
auto shape::dynamic_dimension::reflect(Self& self, F f)
{
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{ {
return (x.min == y.min and x.max == y.max and x.opt == y.opt); // don't check opt if both are fixed
bool check_opt = (x.is_fixed() and y.is_fixed()) ? false : true;
return (x.min == y.min and x.max == y.max and (check_opt ? x.opt == y.opt : true));
} }
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......
...@@ -1116,25 +1116,21 @@ TEST_CASE(multibroadcast) ...@@ -1116,25 +1116,21 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in) TEST_CASE(multibroadcast_2in)
{ {
// static-dyn
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}}, migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}}, migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, b_shape,
b_shape); a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
...@@ -1144,70 +1140,68 @@ TEST_CASE(multibroadcast_2in) ...@@ -1144,70 +1140,68 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
// weirdness begins
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 0}}}, throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
migraphx::make_op("multibroadcast"), throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
a_shape,
b_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
// opt handling // dyn-dyn
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 2}, {6, 6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 3}, {6, 6, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}}, expect_shape(
migraphx::make_op("multibroadcast"), migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
a_shape, migraphx::make_op("multibroadcast"),
b_shape); a_shape,
b_shape);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 1}, {6, 6, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}}, throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
migraphx::make_op("multibroadcast"), throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
a_shape,
b_shape);
} }
// dyn-dyn not handled
{ {
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a}; migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
// fixed-fixed // static-static
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
...@@ -1215,6 +1209,10 @@ TEST_CASE(multibroadcast_2in) ...@@ -1215,6 +1209,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
...@@ -1223,6 +1221,10 @@ TEST_CASE(multibroadcast_2in) ...@@ -1223,6 +1221,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {8, 1}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
...@@ -1231,6 +1233,10 @@ TEST_CASE(multibroadcast_2in) ...@@ -1231,6 +1233,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {4, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
...@@ -1239,6 +1245,10 @@ TEST_CASE(multibroadcast_2in) ...@@ -1239,6 +1245,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
...@@ -1247,11 +1257,16 @@ TEST_CASE(multibroadcast_2in) ...@@ -1247,11 +1257,16 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment