eliminate_contiguous.cpp 3.67 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
9
#include <migraphx/par_for.hpp>
Paul's avatar
Paul committed
10
#include <utility>
11

Paul's avatar
Paul committed
12
namespace migraphx {
Paul's avatar
Paul committed
13
inline namespace MIGRAPHX_INLINE_NS {
14

15
16
17
/// Checks whether `ins` can be given `inputs` as its input shapes without
/// breaking shape computation for `ins` or for the instructions that consume
/// its result.
///
/// The operator of `ins` is asked to compute a shape from `inputs` and `mods`.
/// When the result is non-standard and differs from the current shape of
/// `ins`, every instruction in `ins->outputs()` is checked recursively with
/// the new shape substituted for `ins`.
///
/// @param ins    instruction whose operator is re-evaluated
/// @param inputs candidate input shapes for `ins`
/// @param mods   module arguments handed to the operator of `ins`
/// @return true when the new shape is standard, is unchanged, or is accepted
///         by every consumer; false when `ins` has no consumers and would
///         yield a different non-standard shape, when a consumer rejects the
///         new shape, or when any `compute_shape` call throws
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<shape>& inputs,
                              const std::vector<module_ref>& mods)
{
    try
    {
        shape new_shape = ins->get_operator().compute_shape(inputs, mods);

        // If the output shape is a standard shape, no need to try its output
        if(new_shape.standard())
        {
            return true;
        }

        // if no changes for the shape, the contiguous can also be removed
        if(new_shape == ins->get_shape())
        {
            return true;
        }

        auto outputs = ins->outputs();
        // If the current instruction has no output, it means it is the last
        // instruction and generates a non-standard output shape, and the last
        // output shape is different from the case with the contiguous operator
        if(outputs.empty())
        {
            return false;
        }

        for(auto output : outputs)
        {
            // Shapes the consumer would see: the new shape in place of `ins`,
            // the current shape for every other argument
            auto args = output->inputs();
            std::vector<shape> input_shapes(args.size());
            std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
                return (arg == ins) ? new_shape : arg->get_shape();
            });

            // Each consumer must be checked with its own module arguments;
            // `mods` belongs to `ins` and is not valid for `output`
            if(!try_compute_shape(output, input_shapes, output->module_inputs()))
            {
                return false;
            }
        }
    }
    catch(...)
    {
        // This function is a probe: a throwing compute_shape means "rejected"
        return false;
    }

    return true;
}

65
66
67
/// Overload taking argument instructions instead of shapes: converts `args`
/// with `to_shapes` and forwards to the shape-based `try_compute_shape`.
///
/// @param ins  instruction whose operator is re-evaluated
/// @param args candidate argument instructions for `ins`
/// @param mods module arguments handed to the operator of `ins`
/// @return the result of the shape-based check
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<instruction_ref>& args,
                              const std::vector<module_ref>& mods)
{
    return try_compute_shape(ins, to_shapes(args), mods);
}

73
/// Walks every instruction of `m` and, for each argument whose name equals
/// `op_name`, tries to feed the instruction directly from that argument's
/// first input. When the shape check rejects the bypass but the input can be
/// evaluated, the argument is instead replaced by a precomputed literal.
///
/// NOTE(review): `op_name` is declared outside this view; presumably it names
/// the contiguous operator -- confirm against the header.
void eliminate_contiguous::apply(module& m) const
{
    // Arguments that could not be bypassed but whose input is evaluable
    std::vector<instruction_ref> foldable;

    for(auto ins : iterator_for(m))
    {
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;

        // Working copy of the arguments; replacements accumulate in it
        // across the arguments of this instruction
        auto candidate_args = ins->inputs();
        auto submodules     = ins->module_inputs();

        for(auto arg : ins->inputs())
        {
            if(arg->name() != op_name)
                continue;

            auto source = arg->inputs().front();
            replace(candidate_args, arg, source);

            if(try_compute_shape(ins, candidate_args, submodules))
            {
                // Shapes still work without the intermediate instruction
                instruction::replace_argument(ins, arg, source);
            }
            else if(source->can_eval())
            {
                foldable.push_back(arg);
            }
        }
    }

    // Perform evaluations in parallel
    const std::size_t count = foldable.size();
    std::vector<argument> literals(count);
    par_for(count, 1, [&](const auto i) {
        op::contiguous contiguous_op{};
        auto source = foldable[i]->inputs().front();
        literals[i] = contiguous_op.compute(contiguous_op.compute_shape({source->get_shape()}),
                                            {source->eval()});
    });

    // Swap each collected instruction for the literal holding its result
    for(std::size_t i = 0; i < count; ++i)
    {
        auto& value = literals[i];
        auto lit    = m.add_literal(value.get_shape(), value.data());
        m.replace_instruction(foldable[i], lit);
    }
}

Paul's avatar
Paul committed
121
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
122
} // namespace migraphx