Commit c539a7b0 authored by charlie

Refactor to precompute the dynamic output shape

Also add limitations on broadcasting dynamic shapes
parent 5fc6afe6
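Roughly, this replaces the length-based compute_broadcasted_opt_lens with compute_broadcasted_dyn_dims, which aligns shape::dynamic_dimension values from the trailing dimensions and only broadcasts across a fixed dimension of 1, throwing on any other mismatch instead of returning a sentinel 0. A minimal sketch of the intended behavior (the shapes and the shape constructor form below are illustrative assumptions, not taken from this commit):

// Assumes shape can be built from a type and a list of {min, max, opt} dynamic dimensions.
shape a{shape::float_type, {{2, 2, 0}, {4, 4, 0}, {2, 6, 4}}}; // rank 3, last dim ranges 2..6
shape b{shape::float_type, {{1, 1, 0}, {2, 6, 4}}};            // rank 2
// Ranks are aligned from the right: the fixed 1 broadcasts against {4, 4, 0}, and the
// unmatched leading dimension is taken from the higher-rank shape:
//   compute_broadcasted_dyn_dims(a, b) -> {{2, 2, 0}, {4, 4, 0}, {2, 6, 4}}
// A non-1 mismatch, e.g. {3, 3, 0} against {4, 4, 0}, now throws COMPUTE_BROADCASTED_DYN_DIMS.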
@@ -66,33 +66,50 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return out_lens;
}
// Handling opt dyn_dims calculation
std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
if(s0 == s1)
return s0;
if(s0.size() > s1.size())
s0.swap(s1);
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
if(s0.dynamic() or s1.dynamic())
{
// change both shapes to dynamic_dimension representation
if(not s0.dynamic())
s0 = s0.to_dynamic();
if(not s1.dynamic())
s1 = s1.to_dynamic();
if(s0.rank() > s1.rank())
{
std::swap(s0, s1);
}
auto offset = s1.rank() - s0.rank();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if((a == 1 or b == 1) and a != 0 and b != 0)
else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b))
{
return std::max(a, b);
return shape::dynamic_dimension{
std::max(a.min, b.min), std::max(a.max, b.max), std::max(a.opt, b.opt)};
}
else
{
// if not matching nor 1, set to 0
return static_cast<std::size_t>(0);
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
}
});
return out_dims;
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes");
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
@@ -149,24 +166,36 @@ instruction_ref insert_common_op(module& m,
if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{
// currently only handles the binary case
if(inputs.size() != 2)
{
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs");
}
auto c_type = compute_common_types(to_shapes(inputs));
// broadcast all inputs combinations
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto a_input) {
const auto& ori_input = a_input;
// multibroadcast this input between every other input
std::for_each(inputs.cbegin(), inputs.cend(), [&](auto b_input) {
if(b_input != ori_input)
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// The following should work for a static or a dynamic shape.
// TODO: compute_broadcasted_dyn_dims() will be called again in multibroadcast's
// compute_shape; figure out a way to avoid recomputing it. An attribute in
// multibroadcast?
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{
a_input =
m.insert_instruction(ins, make_op("multibroadcast"), a_input, b_input);
inputs[0] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[0], inputs[1]);
}
});
if(a_input->get_shape().type() != c_type)
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[1], inputs[0]);
}
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
{
a_input = m.insert_instruction(
ins, make_op("convert", {{"target_type", c_type}}), a_input);
input =
m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
}
return a_input;
return input;
});
}
else
......
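With the precomputed c_dyn_dims, insert_common_op now inserts at most one multibroadcast per input (only when that input's dyn_dims differ from the broadcasted dims), followed by a convert to the common type. A rough sketch of the resulting instructions for two dynamic inputs, with illustrative shapes and with the parameter order of insert_common_op and the common type (float) assumed:

// x : float, dyn dims {{2, 6, 4}, {4, 4, 0}}  -- already equal to c_dyn_dims, left alone
// y : half,  dyn dims {{1, 1, 0}, {4, 4, 0}}
// insert_common_op(m, ins, make_op("add"), {x, y}) roughly produces:
//   y1 = multibroadcast(y, x)                   // y's dyn_dims != c_dyn_dims
//   y2 = convert[target_type=float_type](y1)    // half -> common type
//   add(x, y2)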
@@ -37,9 +37,6 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m,
......
@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -63,7 +64,15 @@ struct binary : op_name<Derived>
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
if(s0.dynamic() and s1.dynamic() and s0 == s1)
{
return s0;
}
else if(s0.dynamic() or s1.dynamic())
{
MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
}
else if(s0 == s1 and s0.packed())
{
return s0;
}
@@ -81,9 +90,9 @@ struct binary : op_name<Derived>
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
......
@@ -104,10 +104,51 @@ struct multibroadcast
if(s0.dynamic() and s1.dynamic())
{
// TODO: handle the case where both inputs are dynamic
MIGRAPHX_THROW("MULTIBROADCAST_2IN: two dynamic shape inputs not handled.");
MIGRAPHX_THROW(
"MULTIBROADCAST_2IN: not handled; two dynamic shape inputs");
}
else if(s0.dynamic() or s1.dynamic())
{
// only handles broadcasting a static shape to a dynamic shape;
// every dimension in the static shape must match a fixed dimension in the
// dynamic shape or be 1
// TODO: handle the other possibilities
if(s1.dynamic())
{
std::swap(s0, s1);
}
auto static_rank = s1.lens().size();
auto dyn_rank = s0.max_lens().size();
if(static_rank > dyn_rank)
{
MIGRAPHX_THROW("MULTIBROADCAST_2IN: not handled; static shape has a higher "
"rank than dynamic shape");
}
return s0;
auto offset = dyn_rank - static_rank;
std::vector<shape::dynamic_dimension> out_dims(s0.dyn_dims());
std::transform(s1.lens().begin(),
s1.lens().end(),
s0.dyn_dims().begin() + offset,
out_dims.begin() + offset,
[&](std::size_t l, auto d) {
// l is a static length; d is the aligned dynamic_dimension
if(l == 1 or (d.min == l and d.max == l))
{
return d;
}
else if(d.min == 1 and d.max == 1)
{
// broadcast a fixed dimension of 1 up to the static length
return shape::dynamic_dimension{l, l, 0};
}
else
{
// not matching and not 1: the static shape cannot broadcast to this dynamic shape
MIGRAPHX_THROW("MULTIBROADCAST_2IN: static and dynamic shapes mismatch");
}
});
return {s0.type(), out_dims};
/*
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
@@ -115,6 +156,7 @@ struct multibroadcast
std::move(bcast_min_lens),
std::move(bcast_max_lens),
std::move(bcast_opt_lens)};
*/
}
else
{
......
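Under the rule stated in the multibroadcast comment above, each length of the static shape must either be 1 or match a fixed dynamic_dimension; a few illustrative cases (not taken from this commit):

// dynamic s0: {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}, static s1: {4, 2}
//   -> 4 and 2 match fixed dimensions, the leading {1, 4, 0} is kept:
//      {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}
// dynamic s0: {{2, 6, 4}, {4, 4, 0}}, static s1: {1, 4}
//   -> the 1 broadcasts into {2, 6, 4}: {{2, 6, 4}, {4, 4, 0}}
// dynamic s0: {{2, 6, 4}, {4, 4, 0}}, static s1: {3, 4}
//   -> 3 is neither 1 nor a match for the non-fixed {2, 6, 4}, so MULTIBROADCAST_2IN throws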