eliminate_contiguous.cpp 2.64 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
#include <utility>
8

Paul's avatar
Paul committed
9
namespace migraphx {
Paul's avatar
Paul committed
10
inline namespace MIGRAPHX_INLINE_NS {
11

12
// Checks whether `ins` can accept `inputs` as its input shapes.
//
// The operator of `ins` must be able to compute an output shape from `inputs`.
// If that shape is standard, the check stops here. Otherwise every consumer of
// `ins` is checked recursively, with the new (non-standard) shape substituted
// for `ins` among the consumer's inputs.
//
// Returns true if the shapes are accepted. Returns false if any
// compute_shape() call along the way throws (any exception is treated as
// "not accepted") or if any consumer rejects the shape.
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{
    try
    {
        shape new_shape = ins->get_operator().compute_shape(inputs);

        // If the output shape is a standard shape, no need to try its outputs
        if(new_shape.standard())
        {
            return true;
        }

        // Bind by reference: the recursion below only reads the graph, so
        // copying the consumer list on every step is unnecessary.
        const auto& outputs = ins->outputs();

        // No consumers: nothing downstream has to accept the non-standard
        // output shape, so it is accepted as-is.
        if(outputs.empty())
        {
            return true;
        }

        for(auto output : outputs)
        {
            const auto& args = output->inputs();
            std::vector<shape> input_shapes;
            input_shapes.reserve(args.size());
            for(auto arg : args)
            {
                // Substitute the candidate shape for `ins`; keep the current
                // shape of every other argument.
                input_shapes.push_back((arg == ins) ? new_shape : arg->get_shape());
            }

            if(!try_compute_shape(output, input_shapes))
            {
                return false;
            }
        }
    }
    catch(...)
    {
        // Deliberate best-effort: a throwing compute_shape() means the
        // rewrite is not valid, so it is reported as a rejection.
        return false;
    }

    return true;
}

56
// Convenience overload: checks `ins` against the current output shapes of
// the given argument instructions by forwarding to the shape-vector overload.
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
    return try_compute_shape(ins, to_shapes(args));
}

62
63
64
65
// Removes redundant contiguous instructions from `p`.
//
// For every argument of an instruction that is produced by an operator whose
// name ends with "contiguous", the pass tries to feed the instruction the
// contiguous op's own input instead. The rewrite is kept only when
// try_compute_shape() accepts the resulting argument list, including all
// replacements already committed for the same instruction.
// Instructions named "reshape" are skipped entirely.
void eliminate_contiguous::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // skip the reshape operator for now, since there is a bug
        // for the transpose followed by a reshape
        if(ins->name() == "reshape")
        {
            continue;
        }

        // Iterate over a snapshot: replace_argument() rewrites the argument
        // list of `ins` while this loop runs.
        const auto snapshot = ins->inputs();
        for(auto arg : snapshot)
        {
            // TODO: Pass in names for the operator in the constructor instead
            // of using ends_with
            if(!ends_with(arg->name(), "contiguous"))
            {
                continue;
            }

            auto prev = arg->inputs().front();
            // Start from the *current* arguments, so this candidate is
            // validated together with replacements made in earlier
            // iterations rather than against a stale copy.
            auto new_args = ins->inputs();
            replace(new_args, arg, prev);

            // Nothing to do: `arg` fed several operands and every occurrence
            // was already replaced by an earlier iteration.
            if(new_args == ins->inputs())
            {
                continue;
            }

            if(try_compute_shape(ins, new_args))
            {
                instruction::replace_argument(ins, arg, prev);
            }
        }
    }
}

Paul's avatar
Paul committed
93
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
94
} // namespace migraphx