eliminate_contiguous.cpp 1.3 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
#include <utility>
8

Paul's avatar
Paul committed
9
namespace migraphx {
Paul's avatar
Paul committed
10
inline namespace MIGRAPHX_INLINE_NS {
11

Paul's avatar
Paul committed
12
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
{
    try
    {
        compute_shape(op, args);
    }
    catch(...)
    {
        return false;
    }
    return true;
}

void eliminate_contiguous::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // Make a copy so we can modify it while we iterate
Paul's avatar
Paul committed
30
31
        auto args = ins->inputs();
        for(auto arg : ins->inputs())
32
33
34
        {
            // TODO: Pass in names for the operator in the constructor instead
            // of using ends_with
Paul's avatar
Paul committed
35
            if(ends_with(arg->name(), "contiguous"))
36
37
            {
                auto new_args = args;
Paul's avatar
Paul committed
38
                auto prev     = arg->inputs().front();
39
                replace(new_args, arg, prev);
40
                if(try_compute_shape(ins->get_operator(), new_args))
41
                {
Paul's avatar
Paul committed
42
                    instruction::replace_argument(ins, arg, prev);
43
44
45
46
47
48
                }
            }
        }
    }
}

Paul's avatar
Paul committed
49
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
50
} // namespace migraphx