Commit e801e2f7 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents e2ec9378 aa56068c
......@@ -26,7 +26,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
......@@ -62,8 +61,15 @@ void auto_contiguous::apply(module& m) const
{
if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name()))
continue;
auto outputs = ins->outputs();
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
if(outputs.empty() and ins != last)
continue;
if(not outputs.empty())
// if contiguous was already inserted, skip
if(std::all_of(outputs.begin(), outputs.end(), [](auto output) {
return output->name() == "contiguous";
}))
continue;
shape s = ins->get_shape();
if(s.dynamic())
......@@ -73,9 +79,8 @@ void auto_contiguous::apply(module& m) const
if(s.standard() and ins->name() == "@literal")
continue;
if(s.scalar() and not contains(ins->name(), "broadcast"))
{
continue;
}
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
}
......
......@@ -40,7 +40,7 @@ bool skip_propagate(instruction_ref ins)
if(ins->name() == "contiguous")
return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
if(s.broadcasted() and not s.scalar() and not s.packed())
return true;
if(s.scalar() and s.elements() != 1)
return true;
......@@ -101,9 +101,11 @@ void propagate_constant::apply(module& m) const
})(const_instrs_vec[i]);
m.debug_print(inss);
}
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
auto in_shape = const_instrs_vec[i]->get_shape();
assert(literals[i].get_shape() == in_shape);
literal l{in_shape, literals[i].data()};
auto l0 = m.add_literal(l);
m.replace_instruction(const_instrs_vec[i], l0);
}
}
}
......
......@@ -564,6 +564,17 @@ struct find_inner_broadcast
return 3;
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
......
......@@ -179,7 +179,8 @@ TEST_CASE(standard_reshape_lazy)
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r =
m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
......@@ -201,9 +202,7 @@ TEST_CASE(standard_reshape)
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
// extra contiguous coming from reshape logic which has "requires_std_shape" attribute
auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment