Commit d9d2215a authored by charlie's avatar charlie
Browse files

Redo design

* doesn't make much sense to make broadcast use two inputs or handle
dynamic shapes
* compute the common shape for dynamic multibroadcast in the
multibroadcast op
* multibroadcast all combinations of the dynamic inputs
parent 4d913223
...@@ -31,22 +31,6 @@ ...@@ -31,22 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
auto compute_broadcasting = [](std::vector<std::size_t> s0, std::vector<std::size_t> s1) {
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
};
// Example: // Example:
// s0 = (3,2,4,5) and s1 = (2,1,1) // s0 = (3,2,4,5) and s1 = (2,1,1)
// //
...@@ -66,22 +50,26 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, ...@@ -66,22 +50,26 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return s0; return s0;
if(s0.size() > s1.size()) if(s0.size() > s1.size())
s0.swap(s1); s0.swap(s1);
return compute_broadcasting(s0, s1); std::vector<std::size_t> out_lens(s1);
} auto offset = s1.size() - s0.size();
std::transform(
std::vector<std::size_t> broadcast_s0s1_lens(std::vector<std::size_t> s0, s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
std::vector<std::size_t> s1) if(a != b and a != 1 and b != 1)
{ {
if(s0 == s1) MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
return s0; "} and {" + migraphx::to_string_range(s1) + "} mismatch!");
if(s0.size() > s1.size()) }
MIGRAPHX_THROW("BROADCAST_SHAPE_LENS: s0 size > s1 size and swap not allowed"); return std::max(a, b);
return compute_broadcasting(s0, s1); });
return out_lens;
} }
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{ {
assert(not shapes.empty()); assert(not shapes.empty());
assert(
std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
return transform_accumulate(shapes.begin() + 1, return transform_accumulate(shapes.begin() + 1,
shapes.end(), shapes.end(),
shapes.front().lens(), shapes.front().lens(),
...@@ -127,21 +115,44 @@ instruction_ref insert_common_op(module& m, ...@@ -127,21 +115,44 @@ instruction_ref insert_common_op(module& m,
const operation& op, const operation& op,
std::vector<instruction_ref> inputs) std::vector<instruction_ref> inputs)
{ {
// TODO update this to handle dynamic shapes if(std::any_of(
auto common = common_shape(to_shapes(inputs)); inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { {
if(input->get_shape().lens() != common.lens()) auto c_type = compute_common_types(to_shapes(inputs));
{ // broadcast all inputs permutations
input = m.insert_instruction( std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto a_input) {
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input); const auto& ori_input = a_input;
} // multibroadcast this input between every other input
if(input->get_shape().type() != common.type()) std::for_each(inputs.cbegin(), inputs.cend(), [&](auto b_input) {
{ if(b_input != ori_input)
input = m.insert_instruction( {
ins, make_op("convert", {{"target_type", common.type()}}), input); a_input =
} m.insert_instruction(ins, make_op("multibroadcast"), a_input, b_input);
return input; }
}); });
if(a_input->get_shape().type() != c_type)
{
a_input = m.insert_instruction(
ins, make_op("convert", {{"target_type", c_type}}), a_input);
}
});
}
else
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
}
});
}
return m.insert_instruction(ins, op, inputs); return m.insert_instruction(ins, op, inputs);
} }
......
...@@ -59,71 +59,29 @@ struct broadcast ...@@ -59,71 +59,29 @@ struct broadcast
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto t = s0.type(); auto t = s0.type();
if(inputs.size() == 1) std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{ {
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); MIGRAPHX_THROW("BROADCAST : axis is out of range");
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
} }
else
{
auto s1 = inputs.at(1);
if(axis >= s1.max_lens().size())
{
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range of s1");
}
if(s1.max_lens().size() - axis < s0.max_lens().size())
{
MIGRAPHX_THROW("BROADCAST_2in: (s1 rank - axis) is less than s0 rank");
}
if(s0.dynamic() or s1.dynamic()) if(broadcast_lens.size() - axis < s0.lens().size())
{ {
auto bcast_max_lens = broadcast_s0s1_lens(s0.max_lens(), s1.max_lens()); MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
auto bcast_min_lens = broadcast_s0s1_lens(s0.min_lens(), s1.min_lens()); }
auto bcast_opt_lens = broadcast_s0s1_lens(s0.opt_lens(), s1.opt_lens());
std::vector<shape::dynamic_dimension> output_dyn_dims = {}; if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
for(size_t i = 0; i < bcast_max_lens.size(); ++i) {
{ MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
output_dyn_dims.push_back(shape::dynamic_dimension{
bcast_max_lens[i], bcast_min_lens[i], bcast_opt_lens[i]});
}
return {t, std::move(output_dyn_dims)};
}
else
{
if(not std::equal(s0.lens().begin(), s0.lens().end(), s1.lens().begin() + axis))
{
MIGRAPHX_THROW("BROADCAST_2in: when broadcasting, succeeding sizes must match");
}
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
}
} }
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -34,6 +34,12 @@ namespace migraphx { ...@@ -34,6 +34,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs.
* One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time.
*/
struct multibroadcast struct multibroadcast
{ {
std::vector<std::size_t> output_lens; std::vector<std::size_t> output_lens;
...@@ -93,24 +99,17 @@ struct multibroadcast ...@@ -93,24 +99,17 @@ struct multibroadcast
} }
else else
{ {
// two inputs
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
if(s0.max_lens().size() > s1.max_lens().size())
{
MIGRAPHX_THROW("MULTIBROADCAST: s0 rank should <= s1 rank");
}
if(s0.dynamic() or s1.dynamic()) if(s0.dynamic() or s1.dynamic())
{ {
auto bcast_max_lens = broadcast_s0s1_lens(s0.max_lens(), s1.max_lens()); auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_min_lens = broadcast_s0s1_lens(s0.min_lens(), s1.min_lens()); auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = broadcast_s0s1_lens(s0.opt_lens(), s1.opt_lens()); auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens());
return {t,
std::vector<shape::dynamic_dimension> output_dyn_dims = {}; std::move(bcast_min_lens),
for(size_t i = 0; i < bcast_max_lens.size(); ++i) std::move(bcast_max_lens),
{ std::move(bcast_opt_lens)};
output_dyn_dims.push_back(shape::dynamic_dimension{
bcast_max_lens[i], bcast_min_lens[i], bcast_opt_lens[i]});
}
return {t, std::move(output_dyn_dims)};
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment