fuse_concat.cpp 5.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Operation that stands in for a pointwise -> concat -> pointwise chain after
// the fuse_concat pass has rewritten it.
//
// Module arguments are laid out as one module per concatenated input followed
// by a single post-concat module. The flat `inputs` list holds the arguments
// of the pre-concat modules back to back (any remaining entries belong to the
// post module), so the first argument of each pre-concat module is used as the
// shape that gets concatenated.
struct fused_concat
{
    // Axis along which the inputs are joined.
    int64_t axis = 0;

    std::string name() const { return "fused_concat"; }

    // Exposes `axis` as the only reflected attribute of the operator.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    // Computes the output shape of the fused operation.
    //
    // inputs: flat list of argument shapes (all must have the same rank).
    // mods:   pre-concat modules followed by the post-concat module.
    //
    // The result has the lens of the first concatenated input with the `axis`
    // extent replaced by the sum over all concatenated inputs, the element type
    // of the last instruction of the post module, and a layout derived from the
    // permutation of `inputs`.
    //
    // Throws if modules or inputs are missing, if `axis` is out of range, or if
    // the concatenated inputs disagree on any non-axis dimension.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // A concat of `n` inputs yields `n` pre-concat modules plus one post module, and
        // every pre-concat module consumes at least one input. Exactly one input per
        // pre-concat module (inputs.size() + 1 == mods.size()) is a valid layout, so only
        // reject the cases where modules or inputs are genuinely missing.
        if(mods.size() < 2 or inputs.size() < mods.size() - 1)
            MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
        auto input_iter = inputs.begin();
        std::vector<shape> concat_inputs;
        for(module_ref mod : range(mods.begin(), mods.end() - 1))
        {
            // Guard both the dereference and the advance so the iterator never leaves
            // the input range.
            const auto remaining = static_cast<std::size_t>(inputs.end() - input_iter);
            const auto nparams   = mod->get_parameter_names().size();
            if(remaining == 0 or nparams > remaining)
                MIGRAPHX_THROW("FUSED_CONCAT: Not enough inputs for the fused modules");
            concat_inputs.push_back(*input_iter);
            input_iter += nparams;
        }
        module_ref post_mod          = mods.back();
        auto type                    = std::prev(post_mod->end())->get_shape().type();
        const auto& first_shape_lens = concat_inputs.front().lens();
        // `axis` is used as an iterator offset below; same_ndims() guarantees every input
        // has the rank of the first one, so checking against it is sufficient.
        if(axis < 0 or static_cast<std::size_t>(axis) >= first_shape_lens.size())
            MIGRAPHX_THROW("FUSED_CONCAT: axis " + std::to_string(axis) +
                           " is out of range for inputs with " +
                           std::to_string(first_shape_lens.size()) + " dimensions");
        auto mismatch_it =
            std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](const auto& s) {
                const auto& lens = s.lens();
                return std::equal(lens.begin(),
                                  lens.begin() + axis,
                                  first_shape_lens.begin(),
                                  first_shape_lens.begin() + axis) and
                       std::equal(lens.begin() + axis + 1,
                                  lens.end(),
                                  first_shape_lens.begin() + axis + 1,
                                  first_shape_lens.end());
            });
        if(mismatch_it != concat_inputs.end())
            MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
                           std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
                           "} != {" + to_string_range(mismatch_it->lens()) + "}");

        // Accumulate in std::size_t so the summed extent is not narrowed through int.
        std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(),
                                                        concat_inputs.end(),
                                                        std::size_t{0},
                                                        std::plus<>{},
                                                        [&](const auto& input) {
                                                            return input.lens()[axis];
                                                        });
        auto new_lens  = concat_inputs.front().lens();
        new_lens[axis] = new_dim_axis;
        return shape::from_permutation(type, new_lens, find_permutation(inputs));
    }
};
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {

// Running count appended to the name of each generated "concat:identity" module so that
// every such module gets a distinct name.
static unsigned int counter = 0;
struct find_pointwise_concat_pointwise
{
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
77
78
79
        auto concat = match::name("concat")(
            match::used_once(),
            match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
Paul's avatar
Paul committed
80
81
82
83
84
        return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
85
        auto ins        = r.result;
Paul's avatar
Paul committed
86
87
        auto concat_ins = r.instructions["concat"];

Paul's avatar
Format  
Paul committed
88
89
        auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
                          ins->inputs().begin();
Paul's avatar
Paul committed
90
        std::vector<instruction_ref> inputs;
Paul's avatar
Format  
Paul committed
91
        for(auto input : concat_ins->inputs())
Paul's avatar
Paul committed
92
93
94
95
96
97
        {
            if(input->name() == "pointwise")
                inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
            else
                inputs.push_back(input);
        }
Paul's avatar
Format  
Paul committed
98
99
100
101
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != concat_ins; });
Paul's avatar
Paul committed
102
103

        std::vector<module_ref> module_inputs;
Paul's avatar
Format  
Paul committed
104
105
106
107
108
109
110
111
112
        std::transform(concat_ins->inputs().begin(),
                       concat_ins->inputs().end(),
                       std::back_inserter(module_inputs),
                       [&](instruction_ref input) {
                           if(input->name() == "pointwise")
                           {
                               auto* pm = input->module_inputs().front();
                               return mpm.create_module("concat:" + pm->name(), *pm);
                           }
Paul's avatar
Format  
Paul committed
113
114
                           auto* pm =
                               mpm.create_module("concat:identity" + std::to_string(counter++));
Paul's avatar
Format  
Paul committed
115

Paul's avatar
Paul committed
116
                           auto x  = pm->add_parameter("x0", shape{input->get_shape().type()});
Paul's avatar
Format  
Paul committed
117
118
119
120
121
122
123
                           auto id = pm->add_instruction(make_op("identity"), x);
                           pm->add_return({id});
                           return pm;
                       });

        auto* post_pm                  = ins->module_inputs().front();
        auto* rm                       = mpm.create_module(post_pm->name() + ":concat", *post_pm);
Paul's avatar
Paul committed
124
125
126
        std::vector<std::string> names = rm->get_parameter_names();
        std::sort(names.begin(), names.end());
        auto concat_param_name = names[concat_arg];
Paul's avatar
Format  
Paul committed
127
        auto concat_param      = rm->get_parameter(concat_param_name);
Paul's avatar
Paul committed
128
129
130
131
132
        auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
        rm->replace_instruction(concat_param, param);
        rm->remove_instruction(concat_param);

        module_inputs.push_back(rm);
Paul's avatar
Format  
Paul committed
133
134
135
136
137
        mpm.get_module().replace_instruction(
            ins,
            make_op("fused_concat", concat_ins->normalized_operator().to_value()),
            inputs,
            module_inputs);
Paul's avatar
Paul committed
138
139
140
    }
};

Paul's avatar
Format  
Paul committed
141
} // namespace
Paul's avatar
Paul committed
142
143
144
145
146
147
148
149

// Applies the pointwise -> concat -> pointwise fusion to the module managed by `mpm`.
void fuse_concat::apply(module_pass_manager& mpm) const
{
    find_pointwise_concat_pointwise finder{};
    match::find_matches(mpm, finder);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx