eliminate_contiguous.cpp 3.71 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
Paul's avatar
Paul committed
9
#include <utility>
10

Paul's avatar
Paul committed
11
namespace migraphx {
Paul's avatar
Paul committed
12
inline namespace MIGRAPHX_INLINE_NS {
13

14
15
16
// Check whether `ins` can accept the candidate input shapes `inputs` (typically its
// current inputs with a contiguous bypassed) without changing anything downstream.
//
// The output shape of `ins` is recomputed from `inputs`:
//  - a standard result, or a result equal to the current shape of `ins`, is accepted;
//  - otherwise the new non-standard shape is propagated into every consumer of `ins`,
//    and each consumer is checked recursively with its OWN module arguments.
// Returns false if any shape computation throws, or if a changed non-standard shape
// would reach an instruction that has no consumers.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<shape>& inputs,
                              const std::vector<module_ref>& mods)
{
    try
    {
        shape new_shape = ins->get_operator().compute_shape(inputs, mods);

        // If the output shape is a standard shape, no need to try its output
        if(new_shape.standard())
        {
            return true;
        }

        // if no changes for the shape, the contiguous can also be removed
        if(new_shape == ins->get_shape())
        {
            return true;
        }

        auto outputs = ins->outputs();
        // If the current instruction has no output, it means it is the last
        // instruction and generates a non-standard output shape, and the last
        // output shape is different from the case with the contiguous operator
        if(outputs.empty())
        {
            return false;
        }

        for(auto output : outputs)
        {
            const auto& args = output->inputs();
            std::vector<shape> input_shapes(args.size());
            std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto arg) {
                return (arg == ins) ? new_shape : arg->get_shape();
            });

            // Use the consumer's own module arguments: the module arguments of `ins`
            // belong to `ins`, not to `output`.
            if(not try_compute_shape(output, input_shapes, output->module_inputs()))
            {
                return false;
            }
        }
    }
    catch(...)
    {
        // Deliberate: any failure to compute a shape means the rewrite is not
        // provably safe, so it is simply rejected.
        return false;
    }

    return true;
}

64
65
66
// Overload taking the candidate arguments themselves: checks the rewrite using
// the shapes of `args` and forwards to the shape-based overload above.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<instruction_ref>& args,
                              const std::vector<module_ref>& mods)
{
    return try_compute_shape(ins, to_shapes(args), mods);
}

Paul's avatar
Paul committed
72
73
template<class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
74
{
Paul's avatar
Paul committed
75
    auto last = std::prev(m.end());
76
    for(auto ins : iterator_for(m))
77
    {
78
79
80
81
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;

Paul's avatar
Paul committed
82
83
84
85
86
87
        if (ins != last and ins->outputs().empty())
            continue;

        if (not f(ins))
            continue;

88
        // Make a copy so we can modify it while we iterate
Shucai Xiao's avatar
Shucai Xiao committed
89
90
91
        auto args     = ins->inputs();
        auto new_args = args;
        auto mod_args = ins->module_inputs();
Paul's avatar
Paul committed
92
        for(auto arg : ins->inputs())
93
        {
Paul's avatar
Paul committed
94
95
96
97
98
99
100
101
102
            if(arg->name() != op_name)
                continue;
            auto prev = arg->inputs().front();
            replace(new_args, arg, prev);
            if(try_compute_shape(ins, new_args, mod_args))
            {
                instruction::replace_argument(ins, arg, prev);
            }
            else if(prev->can_eval())
103
            {
Paul's avatar
Paul committed
104
105
                auto c = op::contiguous{};
                auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
Paul's avatar
Paul committed
106

Paul's avatar
Paul committed
107
108
                auto l = m.add_literal(r.get_shape(), r.data());
                m.replace_instruction(arg, l);
109
110
111
112
113
            }
        }
    }
}

Paul's avatar
Paul committed
114
115
116
117
118
119
120
121
122
123
124
// Run contiguous elimination over `m` in two passes.
// Pass one skips any "slice" whose input also feeds other instructions (the
// split case); pass two considers every instruction.
void eliminate_contiguous::apply(module& m) const
{
    // A slice is eligible in the first pass only if it is the sole consumer of its input.
    auto sole_consumer_slice_or_other = [](auto ins) {
        if(ins->name() == "slice")
            return ins->inputs().front()->outputs().size() == 1;
        return true;
    };

    // Skip contiguous from splits first
    remove_contiguous(op_name, m, sole_consumer_slice_or_other);

    // Then consider every instruction.
    remove_contiguous(op_name, m, [](auto) { return true; });
}

Paul's avatar
Paul committed
125
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
126
} // namespace migraphx