eliminate_contiguous.cpp 1.32 KB
Newer Older
1
2
3
4
5
6
7
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
Paul's avatar
Paul committed
8
#include <algorithm>
#include <utility>
9
10

namespace migraph {
11
inline namespace MIGRAPH_INLINE_NS {
12

Paul's avatar
Paul committed
13
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
{
    try
    {
        compute_shape(op, args);
    }
    catch(...)
    {
        return false;
    }
    return true;
}

/// Bypasses `contiguous` instructions wherever the consumer can accept the
/// un-copied input directly.
///
/// For every instruction `ins` in `p`, each input produced by an operator
/// whose name ends with "contiguous" is tentatively replaced by that
/// operator's first input. The replacement is committed only if `ins`'s
/// operator can still compute a shape from the resulting argument list.
///
/// @param p program rewritten in place
void eliminate_contiguous::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // Snapshot of the inputs to iterate over: replace_argument() below
        // rewrites ins->inputs(), so we must not loop over that vector itself.
        const auto original_args = ins->inputs();
        // Current argument list including every replacement accepted so far,
        // so each shape check sees the combined effect of earlier eliminations
        // instead of the stale original list.
        auto args = original_args;
        for(auto arg : original_args)
        {
            // An earlier iteration may already have bypassed this argument
            // (e.g. the same contiguous op feeding several inputs); it is no
            // longer an input of ins, so there is nothing left to replace.
            if(std::find(args.begin(), args.end(), arg) == args.end())
                continue;
            // TODO: Pass in names for the operator in the constructor instead
            // of using ends_with
            if(not ends_with(arg->name(), "contiguous"))
                continue;
            // NOTE(review): assumes a contiguous op always has exactly one
            // input — confirm.
            auto prev     = arg->inputs().front();
            auto new_args = args;
            replace(new_args, arg, prev);
            // NOTE(review): only ins's own shape computation is validated; if
            // bypassing changes ins's output shape, consumers of ins are not
            // re-checked here — confirm this is safe.
            if(try_compute_shape(ins->get_operator(), new_args))
            {
                instruction::replace_argument(ins, arg, prev);
                args = std::move(new_args);
            }
        }
    }
}

50
} // namespace MIGRAPH_INLINE_NS
51
} // namespace migraph