#include <migraphx/eliminate_concat.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/tune_axis.hpp>
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
16
void eliminate_concat::apply(module& p) const
17
18
19
20
21
22
{
    for(auto ins : iterator_for(p))
    {
        // Look for the concat operator
        if(ins->name() != concat_opt.name())
            continue;
23
24
25
26
27
28
29
30
        // If any inputs are builtin or context free then abort
        // If any inputs are used more than once, then abort since there could
        // be errors due to aliasing
        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto arg) {
               return arg->name().front() == '@' or
                      (arg->get_operator().is_context_free() and
                       not contains({"concat", "identity"}, arg->name())) or
                      arg->outputs().size() > 1;
31
32
           }))
            continue;
Scott Thornton's avatar
Scott Thornton committed
33
        // We can only do this optimization when concat axis is either the leftmost
34
35
36
        // axis OR the sizes to the left of this axis are all equal to 1
        // Since we've already checked that the non-axis dimensions are identical
        // we only need to check the first input
37
38
39
        auto lens              = ins->inputs().front()->get_shape().lens();
        auto concat_op         = concat_opt.get_concat(ins->get_operator());
        std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
Shucai Xiao's avatar
Shucai Xiao committed
40
        if(axis_index == 0 ||
41
           std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
42
43
44
        {
            // Last input should be an allocation
            auto last = ins->inputs().back();
Scott Thornton's avatar
Scott Thornton committed
45
46
            if(last->name() != concat_opt.allocate())
                continue;
47
48
49
            // Where are the allocations for the tensors to be concatenated?
            std::vector<instruction_ref> allocations;

Paul's avatar
Paul committed
50
51
52
53
54
            std::transform(
                ins->inputs().begin(),
                std::prev(ins->inputs().end()),
                std::back_inserter(allocations),
                [&](instruction_ref x) { return instruction::get_output_alias(x, true); });
Paul's avatar
Paul committed
55

Paul's avatar
Paul committed
56
57
58
            if(std::any_of(allocations.begin(), allocations.end(), [&](auto x) {
                   return x->name() != concat_opt.allocate();
               }))
Paul's avatar
Paul committed
59
60
                continue;

Scott Thornton's avatar
Scott Thornton committed
61
            // Need to sort the allocations, so that we know where to
62
            // insert the "super"-allocation
63
64
65
66
67
68
            auto sorted_allocations = allocations;
            std::sort(sorted_allocations.begin(),
                      sorted_allocations.end(),
                      [&](instruction_ref x, instruction_ref y) {
                          return std::distance(p.begin(), x) < std::distance(p.begin(), y);
                      });
69
            // Move "super" allocation to the front
70
            auto first = sorted_allocations.front();
Paul's avatar
Paul committed
71
            auto super = p.move_instruction(last, first);
Paul's avatar
Paul committed
72
            // Replace each allocation with a load
73
            std::size_t offset = 0;
Paul's avatar
Paul committed
74
            for(auto alloc : allocations)
75
            {
Paul's avatar
Paul committed
76
77
78
                op::load op{alloc->get_shape(), offset};
                p.replace_instruction(alloc, op, {super});
                offset += alloc->get_shape().bytes();
79
80
            }
            std::vector<instruction_ref> args = {super};
Scott Thornton's avatar
Scott Thornton committed
81
            std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
82
            p.replace_instruction(ins, migraphx::make_op("identity"), args);
83
84
85
        }
    }
}
Paul's avatar
Paul committed
86
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
87
} // namespace migraphx