Commit 834bb1bb authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

auto_contiguous always add contiguous after reshapes

These will get cleaned up later but result in us adding contiguous after
every reshape prior to us performing a find_reshape_alias matcher
parent 210ea72d
...@@ -25,12 +25,31 @@ ...@@ -25,12 +25,31 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
const auto& reshaper_op_names()
{
// clang-format off
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"contiguous",
"squeeze",
"unsqueeze"
"transpose",
"multibroadcast",
"broadcast"
};
// clang-format on
return names;
}
bool is_reshaper_op(instruction_ref ins) { return contains(reshaper_op_names(), ins->name()); }
void auto_contiguous::apply(module& m) const void auto_contiguous::apply(module& m) const
{ {
std::string key = "require_std_shape"; std::string key = "require_std_shape";
...@@ -64,6 +83,14 @@ void auto_contiguous::apply(module& m) const ...@@ -64,6 +83,14 @@ void auto_contiguous::apply(module& m) const
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
// perform a pass to insert contiguous for every reshape (without reshaper) before
// determining if aliasing can be performed
if(ins->name() == "reshape" and not is_reshaper_op(std::next(ins)))
{
m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
}
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.dynamic() and not s.standard() and s.elements() != 0) if(not s.dynamic() and not s.standard() and s.elements() != 0)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment