// eliminate_contiguous.cpp
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
#include <utility>

namespace migraph {

bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
{
    try
    {
        compute_shape(op, args);
    }
    catch(...)
    {
        return false;
    }
    return true;
}

/// Bypasses "contiguous" instructions wherever their consumer does not need them.
///
/// For every instruction in the program, each argument produced by an operator
/// whose name ends with "contiguous" is a candidate. The consumer is rewired
/// to read that contiguous op's first input directly. This happens only if
/// compute_shape still succeeds for the consumer's operator with the rewired
/// argument list.
/// @param p the program to transform in place
void eliminate_contiguous::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // Iterate over a snapshot rather than the live container:
        // replace_argument below rewrites this instruction's arguments.
        const auto args = ins->arguments;
        for(auto arg : args)
        {
            // TODO: Pass in names for the operator in the constructor instead
            // of using ends_with
            if(!ends_with(arg->name(), "contiguous"))
                continue;
            // Skip a malformed contiguous op with no input instead of calling
            // front() on an empty vector.
            if(arg->arguments.empty())
                continue;
            auto prev = arg->arguments.front();
            // Validate against the instruction's *current* arguments. That
            // list includes replacements already made for this instruction.
            // Using the pre-loop snapshot would check each removal in
            // isolation, while the instruction ends up with all of them
            // applied together.
            auto new_args = ins->arguments;
            replace(new_args, arg, prev);
            if(try_compute_shape(ins->op, new_args))
            {
                replace_argument(ins, arg, prev);
            }
        }
    }
}

} // namespace migraph