"...frontend/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "2a31ba430a2a7e66a84e068622c43a90ed95dfb1"
fuse_concat.cpp 5.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

#include <algorithm>
#include <iterator>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Operator produced by the fuse_concat pass: a concat whose arguments are each
// computed by a "pre" module and whose result is consumed by a trailing "post"
// module.
struct fused_concat
{
    // Concatenation axis. The rewrite in this file fills it from the concat's
    // normalized operator value.
    int64_t axis = 0;

    std::string name() const { return "fused_concat"; }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    /// Compute the output shape of the fused concat.
    ///
    /// @param inputs Argument shapes: the parameters of each pre module in module
    ///        order (each pre module consumes as many consecutive inputs as it
    ///        has parameters), followed by the post module's extra arguments.
    /// @param mods One pre module per concat argument, followed by the post module.
    /// @return Shape whose element type is that of the post module's last
    ///         instruction and whose lens are the first concat argument's lens
    ///         with dimension `axis` replaced by the sum over all concat arguments.
    /// @throws When modules or their inputs are missing, `axis` is out of range,
    ///         or the concat arguments disagree on a non-axis dimension.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // At least one pre module plus the post module is required. Graphs where
        // every pre module takes exactly one argument (inputs.size() + 1 ==
        // mods.size()) are valid and must not be rejected here.
        if(mods.size() < 2)
            MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");

        // The first input consumed by each pre module stands in for that concat
        // argument's shape.
        auto input_iter = inputs.begin();
        std::vector<shape> concat_inputs;
        concat_inputs.reserve(mods.size() - 1);
        for(module_ref mod : range(mods.begin(), mods.end() - 1))
        {
            const std::size_t nparams = mod->get_parameter_names().size();
            const auto remaining =
                static_cast<std::size_t>(std::distance(input_iter, inputs.end()));
            // Check before advancing: moving input_iter past inputs.end() is UB.
            if(remaining == 0 or remaining < nparams)
                MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules inputs parameters");
            concat_inputs.push_back(*input_iter);
            input_iter += nparams;
        }

        module_ref post_mod = mods.back();
        // Output element type comes from the post module's final instruction.
        auto type = std::prev(post_mod->end())->get_shape().type();

        const auto& first_shape_lens = concat_inputs.front().lens();
        const auto ndim              = static_cast<int64_t>(first_shape_lens.size());
        // Guards the iterator arithmetic below (lens.begin() + axis + 1).
        if(axis < 0 or axis >= ndim)
            MIGRAPHX_THROW("FUSED_CONCAT: axis " + std::to_string(axis) +
                           " is out of range for rank " + std::to_string(ndim));

        // Every concat argument must match the first one on all non-axis dimensions.
        auto mismatch_it =
            std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](const shape& s) {
                const auto& lens = s.lens();
                return std::equal(lens.begin(),
                                  lens.begin() + axis,
                                  first_shape_lens.begin(),
                                  first_shape_lens.begin() + axis) and
                       std::equal(lens.begin() + axis + 1,
                                  lens.end(),
                                  first_shape_lens.begin() + axis + 1,
                                  first_shape_lens.end());
            });
        if(mismatch_it != concat_inputs.end())
            MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
                           std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
                           "} != {" + to_string_range(mismatch_it->lens()) + "}");

        // Seed with std::size_t so the lens sum is not narrowed through `int`.
        std::size_t new_dim_axis =
            transform_accumulate(concat_inputs.begin(),
                                 concat_inputs.end(),
                                 std::size_t{0},
                                 std::plus<>{},
                                 [&](const shape& input) { return input.lens()[axis]; });
        auto new_lens  = concat_inputs.front().lens();
        new_lens[axis] = new_dim_axis;
        return shape::from_permutation(type, new_lens, find_permutation(inputs));
    }
};
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {

static unsigned int counter = 0;
struct find_pointwise_concat_pointwise
{
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
74
75
76
        auto concat = match::name("concat")(
            match::used_once(),
            match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
Paul's avatar
Paul committed
77
78
79
80
81
        return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
82
        auto ins        = r.result;
Paul's avatar
Paul committed
83
84
        auto concat_ins = r.instructions["concat"];

Paul's avatar
Format  
Paul committed
85
86
        auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
                          ins->inputs().begin();
Paul's avatar
Paul committed
87
        std::vector<instruction_ref> inputs;
Paul's avatar
Format  
Paul committed
88
        for(auto input : concat_ins->inputs())
Paul's avatar
Paul committed
89
90
91
92
93
94
        {
            if(input->name() == "pointwise")
                inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
            else
                inputs.push_back(input);
        }
Paul's avatar
Format  
Paul committed
95
96
97
98
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != concat_ins; });
Paul's avatar
Paul committed
99
100

        std::vector<module_ref> module_inputs;
Paul's avatar
Format  
Paul committed
101
102
103
104
105
106
107
108
109
        std::transform(concat_ins->inputs().begin(),
                       concat_ins->inputs().end(),
                       std::back_inserter(module_inputs),
                       [&](instruction_ref input) {
                           if(input->name() == "pointwise")
                           {
                               auto* pm = input->module_inputs().front();
                               return mpm.create_module("concat:" + pm->name(), *pm);
                           }
Paul's avatar
Paul committed
110
                           auto* pm = mpm.create_module("concat:identity" + std::to_string(counter++));
Paul's avatar
Format  
Paul committed
111

Paul's avatar
Paul committed
112
                           auto x  = pm->add_parameter("x0", shape{input->get_shape().type()});
Paul's avatar
Format  
Paul committed
113
114
115
116
117
118
119
                           auto id = pm->add_instruction(make_op("identity"), x);
                           pm->add_return({id});
                           return pm;
                       });

        auto* post_pm                  = ins->module_inputs().front();
        auto* rm                       = mpm.create_module(post_pm->name() + ":concat", *post_pm);
Paul's avatar
Paul committed
120
121
122
        std::vector<std::string> names = rm->get_parameter_names();
        std::sort(names.begin(), names.end());
        auto concat_param_name = names[concat_arg];
Paul's avatar
Format  
Paul committed
123
        auto concat_param      = rm->get_parameter(concat_param_name);
Paul's avatar
Paul committed
124
125
126
127
128
        auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
        rm->replace_instruction(concat_param, param);
        rm->remove_instruction(concat_param);

        module_inputs.push_back(rm);
Paul's avatar
Format  
Paul committed
129
130
131
132
133
        mpm.get_module().replace_instruction(
            ins,
            make_op("fused_concat", concat_ins->normalized_operator().to_value()),
            inputs,
            module_inputs);
Paul's avatar
Paul committed
134
135
136
    }
};

Paul's avatar
Format  
Paul committed
137
} // namespace
Paul's avatar
Paul committed
138
139
140
141
142
143
144
145

void fuse_concat::apply(module_pass_manager& mpm) const
{
    // Single rewrite: fold pointwise -> concat -> pointwise chains into one
    // fused_concat instruction.
    const auto pointwise_concat_finder = find_pointwise_concat_pointwise{};
    match::find_matches(mpm, pointwise_concat_finder);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx