#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Check whether `ins` -- and, transitively, every instruction consuming it --
/// can still compute a valid output shape when `ins` receives inputs with the
/// shapes in `inputs` instead of its current input shapes.
///
/// @param ins    instruction whose operator is re-run through compute_shape
/// @param inputs candidate input shapes for `ins`, in argument order
/// @param mods   module arguments of `ins`, forwarded to its compute_shape
/// @return true when the change is acceptable: `ins` now produces a standard
///         shape, or its output shape is unchanged, or every consumer accepts
///         the new shape. Any exception thrown by compute_shape (for `ins` or
///         a consumer) is reported as false.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<shape>& inputs,
                              const std::vector<module_ref>& mods)
{
    try
    {
        shape new_shape = ins->get_operator().compute_shape(inputs, mods);

        // If the output shape is a standard shape, there is no need to check
        // the consumers of this instruction
        if(new_shape.standard())
        {
            return true;
        }

        // If the output shape did not change, the contiguous can also be removed
        if(new_shape == ins->get_shape())
        {
            return true;
        }

        const auto& outputs = ins->outputs();
        // No consumers means `ins` is the last instruction and would now emit
        // a non-standard shape that differs from the one it produced with the
        // contiguous in place, so the change is rejected
        if(outputs.empty())
        {
            return false;
        }

        for(auto output : outputs)
        {
            const auto& args = output->inputs();
            std::vector<shape> input_shapes(args.size());
            std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto arg) {
                return (arg == ins) ? new_shape : arg->get_shape();
            });

            // Each consumer must be re-checked with its OWN module arguments,
            // not the ones belonging to `ins`
            if(not try_compute_shape(output, input_shapes, output->module_inputs()))
            {
                return false;
            }
        }
    }
    catch(...)
    {
        // compute_shape signals unsupported input shapes by throwing; treat any
        // such failure as "the contiguous cannot be removed"
        return false;
    }

    return true;
}

/// Convenience overload: checks whether `ins` accepts the candidate argument
/// list `args` by converting the arguments to their shapes and delegating to
/// the shape-based overload.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<instruction_ref>& args,
                              const std::vector<module_ref>& mods)
{
    return try_compute_shape(ins, to_shapes(args), mods);
}

template <class F>
Paul's avatar
Paul committed
74
static void remove_contiguous(const std::string& op_name, module& m, F f)
75
{
Paul's avatar
Paul committed
76
    auto last = std::prev(m.end());
77
78
    std::vector<instruction_ref> const_instruction;

79
    for(auto ins : iterator_for(m))
80
    {
81
82
83
84
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;

Paul's avatar
Format  
Paul committed
85
        if(ins != last and ins->outputs().empty())
Paul's avatar
Paul committed
86
87
            continue;

Paul's avatar
Format  
Paul committed
88
        if(not f(ins))
Paul's avatar
Paul committed
89
90
            continue;

91
        // Make a copy so we can modify it while we iterate
Shucai Xiao's avatar
Shucai Xiao committed
92
93
94
        auto args     = ins->inputs();
        auto new_args = args;
        auto mod_args = ins->module_inputs();
95

Paul's avatar
Paul committed
96
        for(auto arg : ins->inputs())
97
        {
Paul's avatar
Paul committed
98
99
100
101
102
103
104
105
106
            if(arg->name() != op_name)
                continue;
            auto prev = arg->inputs().front();
            replace(new_args, arg, prev);
            if(try_compute_shape(ins, new_args, mod_args))
            {
                instruction::replace_argument(ins, arg, prev);
            }
            else if(prev->can_eval())
107
108
            {
                replace(new_args, arg, prev);
Shucai Xiao's avatar
Shucai Xiao committed
109
                if(try_compute_shape(ins, new_args, mod_args))
110
                {
Paul's avatar
Paul committed
111
                    instruction::replace_argument(ins, arg, prev);
112
                }
Paul's avatar
Paul committed
113
                else if(prev->can_eval())
Paul's avatar
Paul committed
114
                {
115
                    const_instruction.push_back(arg);
Paul's avatar
Paul committed
116
                }
117
118
119
            }
        }
    }
120
121
122
123
124
125
126
127
128
129
130
131
132
133

    // Perform evaluations in parallel
    std::vector<argument> literals(const_instruction.size());
    par_for(const_instruction.size(), 1, [&](const auto i) {
        auto c      = op::contiguous{};
        auto prev   = const_instruction[i]->inputs().front();
        literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
    });

    for(size_t i = 0; i < const_instruction.size(); i++)
    {
        auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
        m.replace_instruction(const_instruction[i], l);
    }
134
135
}

void eliminate_contiguous::apply(module& m) const
{
    // First pass: leave alone slices whose input is shared with other
    // instructions (contiguous coming out of splits); visit everything else
    auto unshared_slice_or_other = [](auto ins) {
        return ins->name() != "slice" or ins->inputs().front()->outputs().size() == 1;
    };
    remove_contiguous(op_name, m, unshared_slice_or_other);

    // Second pass: visit every instruction
    auto every_instruction = [](auto) { return true; };
    remove_contiguous(op_name, m, every_instruction);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx