// layout_nhwc.cpp
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <unordered_set>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template<class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
    std::vector<instruction_ref> result;
    fix([&](auto self, auto ins) {
        if (pred(ins))
        {
            result.push_back(ins);
            return;
        }
        for(auto input:ins->inputs())
            self(input);
    })(std::prev(m.end()));
    return result;
}

/// Pin the current layout of the module's outermost rank-4 instructions.
///
/// find_lasts locates, searching back from the final instruction, the first
/// instruction on each path whose shape has four dimensions. For each one a
/// "layout" op carrying that instruction's present permutation is inserted
/// directly after it, and every former user is rerouted through the new op.
/// Going by the function name, the intent is to keep the layout seen at the
/// module output unchanged by the NHWC rewrites that run afterwards.
void preserve_output_layout(module& m)
{
    auto is_rank4 = [](auto ins) { return ins->get_shape().lens().size() == 4; };
    for(auto last : find_lasts(m, is_rank4))
    {
        auto perm      = find_permutation(last->get_shape());
        auto layout_op = make_op("layout", {{"permutation", perm}});
        auto pinned    = m.insert_instruction(std::next(last), layout_op, last);
        // Redirect users of `last` to the pinned copy (pinned itself keeps `last` as input)
        m.replace_instruction(last, pinned);
    }
}


/// Rewrite every rank-4 "convolution" instruction to operate on permuted inputs.
///
/// For each matching convolution, every input is wrapped in a "layout" op with
/// permutation {0, 2, 3, 1}. A copy of the convolution operator is then applied
/// to those wrapped inputs and followed by a "contiguous" op, which takes over
/// all users of the original convolution. The original instruction is not
/// erased here. It is presumably cleaned up by the dead_code_elimination that
/// layout_nhwc::apply runs next.
void transform_convolutions(module& m)
{
    for(auto ins : iterator_for(m))
    {
        const bool is_rank4_conv =
            ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
        if(not is_rank4_conv)
            continue;

        // New instructions go before `ins`, so the forward iteration never revisits them
        std::vector<instruction_ref> nhwc_inputs;
        nhwc_inputs.reserve(ins->inputs().size());
        for(auto input : ins->inputs())
        {
            nhwc_inputs.push_back(m.insert_instruction(
                ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), input));
        }

        auto nhwc_conv = m.insert_instruction(ins, ins->get_operator(), nhwc_inputs);
        auto packed    = m.insert_instruction(ins, make_op("contiguous"), nhwc_conv);
        m.replace_instruction(ins, packed);
    }
}

/// Pass entry point: move rank-4 convolutions in `m` onto {0, 2, 3, 1}-permuted inputs.
///
/// Steps:
///  1. preserve_output_layout pins the layout of the outermost rank-4
///     instructions, so the module's visible output layout is kept.
///  2. transform_convolutions rebuilds each rank-4 convolution on permuted
///     inputs, followed by a "contiguous" op.
///  3. dead_code_elimination drops the convolutions that were replaced.
///  4. eliminate_contiguous presumably removes "contiguous" ops it finds
///     unnecessary.
///  5. A second dead_code_elimination removes anything step 4 left without
///     users, so no orphans remain when the pass returns.
void layout_nhwc::apply(module& m) const
{
    preserve_output_layout(m);
    transform_convolutions(m);
    dead_code_elimination{}.apply(m);
    eliminate_contiguous{"contiguous"}.apply(m);
    // eliminate_contiguous reroutes users around the ops it bypasses; clean up
    // whatever it leaves unused (NOTE(review): same pairing as other pipelines)
    dead_code_elimination{}.apply(m);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx