Commit 2929d51c authored by charlie's avatar charlie
Browse files

only handle dyn-fixed and fixed-fixed for now

parent 940220c8
......@@ -95,50 +95,6 @@ std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s
return out_lens;
}
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
if(s0 == s1)
return s0.dyn_dims();
if(not s0.dynamic() or not s1_dynamic())
{
// mixed fixed and dynamic
if(s0.dynamic())
s0.swap(s1);
}
else
{
// both dynamic
if(s0.dyn_dims().size() > s1.dyn_dims().size())
std::swap(s0, s1);
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
auto offset = s1.size() - s0.size();
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(not contains(one_dyn_dims, a) and not contains(one_dyn_dims, b))
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0) + "} and {" +
migraphx::to_string_range(s1) + "} mismatch!");
}
else
{
return shape::dynamic_dimension{std::max(a.min, b.min),
std::max(a.max, b.max),
(a.opt != b.opt) ? 0 : a.opt};
}
});
return out_dims;
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
......@@ -194,7 +150,7 @@ instruction_ref insert_common_op(module& m,
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{
auto c_type = compute_common_types(to_shapes(inputs));
// broadcast all inputs permutations
// broadcast all inputs combinations
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto a_input) {
const auto& ori_input = a_input;
// multibroadcast this input between every other input
......
......@@ -101,7 +101,12 @@ struct multibroadcast
{
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic() or s1.dynamic())
if(s0.dynamic() and s1.dynamic())
{
// TODO handle both dynamic case
MIGRAPHX_THROW("MULTIBROADCAST_2IN: two dynamic shape inputs not handled.");
}
else if(s0.dynamic() or s1.dynamic())
{
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment