eliminate_contiguous.cpp 3.84 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
Paul's avatar
Paul committed
9
#include <utility>
10

Paul's avatar
Paul committed
11
namespace migraphx {
Paul's avatar
Paul committed
12
inline namespace MIGRAPHX_INLINE_NS {
13

14
/// Decide whether `ins` can be re-evaluated on the given input shapes without
/// affecting anything downstream.
///
/// The operator of `ins` is re-run on `inputs`. The substitution is accepted
/// when the resulting shape is standard, when it equals the shape `ins`
/// already has, or (recursively) when every consumer of `ins` accepts the new
/// shape in place of the old one. A consumer-less instruction whose shape
/// would change is rejected. Any exception raised while computing a shape
/// rejects the substitution.
///
/// @param ins    instruction whose inputs would be replaced
/// @param inputs candidate input shapes, in argument order
/// @return true if the substitution is safe, false otherwise
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{
    try
    {
        const shape candidate = ins->get_operator().compute_shape(inputs);

        // A standard result, or an unchanged one, cannot disturb consumers.
        if(candidate.standard() || candidate == ins->get_shape())
            return true;

        const auto consumers = ins->outputs();

        // Nothing consumes ins, so it is a terminal instruction whose visible
        // (non-standard) shape would differ from the one produced with the
        // contiguous operator in place.
        if(consumers.empty())
            return false;

        // Every consumer must itself accept the changed shape of ins.
        return std::all_of(consumers.begin(), consumers.end(), [&](instruction_ref consumer) {
            const auto& consumer_args = consumer->inputs();
            std::vector<shape> consumer_shapes;
            consumer_shapes.reserve(consumer_args.size());
            for(const auto& a : consumer_args)
                consumer_shapes.push_back(a == ins ? candidate : a->get_shape());
            return try_compute_shape(consumer, consumer_shapes);
        });
    }
    catch(...)
    {
        // compute_shape signals an invalid shape combination by throwing.
        return false;
    }
}

62
/// Convenience overload: run the shape-substitution check using the shapes
/// currently carried by the instructions in `args`.
///
/// @param ins  instruction whose inputs would be replaced
/// @param args candidate input instructions, in argument order
/// @return true if the substitution is safe, false otherwise
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
    return try_compute_shape(ins, to_shapes(args));
}

68
/// Bypass `op_name` instructions feeding other instructions when the bypass
/// does not change any shape that matters downstream.
///
/// `op_name` is assumed to name a contiguous-like operator: the constant
/// folding below evaluates it with op::contiguous. TODO confirm for every
/// configuration of this pass.
///
/// For every instruction consuming at least one `op_name` input (except
/// `@return`, whose inputs must keep a standard shape):
///   1. try bypassing all such inputs at once. If try_compute_shape accepts
///      it, rewire every bypassed argument.
///   2. otherwise, bypass them one at a time. Each candidate is checked
///      against the instruction's *current* inputs, so earlier accepted
///      bypasses are part of the check. An input that cannot be bypassed,
///      but whose source can be evaluated at compile time, is replaced by a
///      literal holding the contiguous result.
///
/// @param p module rewritten in place
void eliminate_contiguous::apply(module& p) const
{
    for(auto ins : iterator_for(p))
    {
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;

        if(std::none_of(ins->inputs().begin(), ins->inputs().end(), [&](auto arg) {
               return arg->name() == op_name;
           }))
            continue;

        // Snapshot of the inputs, so they can be modified while we iterate
        auto args     = ins->inputs();
        auto new_args = args;
        std::transform(new_args.begin(), new_args.end(), new_args.begin(), [&](auto arg) {
            if(arg->name() == op_name)
                return arg->inputs().front();
            else
                return arg;
        });
        assert(args.size() == new_args.size());

        if(try_compute_shape(ins, new_args))
        {
            // All contiguous inputs can be bypassed together.
            for(auto i : range(args.size()))
            {
                if(args[i] == new_args[i])
                    continue;
                instruction::replace_argument(ins, args[i], new_args[i]);
            }
        }
        else
        {
            for(auto arg : ins->inputs())
            {
                if(arg->name() != op_name)
                    continue;

                // Check against the CURRENT inputs, not the original snapshot.
                // Otherwise two bypasses that are each valid alone could both
                // be applied, recreating the combination rejected above.
                new_args  = ins->inputs();
                auto prev = arg->inputs().front();
                replace(new_args, arg, prev);
                if(try_compute_shape(ins, new_args))
                {
                    instruction::replace_argument(ins, arg, prev);
                }
                else if(prev->can_eval())
                {
                    // Cannot bypass, but the source is constant: fold the
                    // contiguous copy into a literal at compile time.
                    auto c = op::contiguous{};
                    auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});

                    auto l = p.add_literal(r.get_shape(), r.data());
                    p.replace_instruction(arg, l);
                }
            }
        }
    }
}

Paul's avatar
Paul committed
127
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
128
} // namespace migraphx