eliminate_contiguous.cpp 6.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
25
26
27
28
29
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>

#include <algorithm>
#include <type_traits>
#include <utility>

Paul's avatar
Paul committed
36
namespace migraphx {
Paul's avatar
Paul committed
37
inline namespace MIGRAPHX_INLINE_NS {
38

39
40
// Debug switch for this pass. When enabled, the pass prints each instruction
// whose contiguous input it is about to try to bypass, and reports any
// exception raised while re-computing shapes during that attempt.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS)

41
42
43
/// Check whether `ins` can be re-evaluated with `inputs` without breaking the
/// instructions downstream of it.
///
/// The operator's shape is recomputed with `inputs`/`mods`. The result is
/// accepted when it is standard, or identical to the shape `ins` already has.
/// Otherwise every consumer of `ins` is checked recursively with the new shape
/// substituted for `ins`.
///
/// @return true if the new shape is acceptable. Returns false for a dynamic
///         result, for a changed non-standard result with no consumers, when
///         any consumer rejects it, or when shape computation throws.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<shape>& inputs,
                              const std::vector<module_ref>& mods)
{
    try
    {
        shape candidate = ins->get_operator().compute_shape(inputs, mods);

        // A dynamic shape gives no way to know whether a contiguous is needed.
        if(candidate.dynamic())
            return false;

        // A standard result, or an unchanged one, cannot disturb any consumer.
        if(candidate.standard() or candidate == ins->get_shape())
            return true;

        // Changed, non-standard, and nothing downstream to validate against:
        // the final output would differ from the version with the contiguous.
        auto consumers = ins->outputs();
        if(consumers.empty())
            return false;

        // Every consumer must still compute a valid shape when it sees the
        // candidate shape in place of the current shape of `ins`.
        return std::all_of(consumers.begin(), consumers.end(), [&](instruction_ref consumer) {
            const auto& consumer_args = consumer->inputs();
            std::vector<shape> consumer_shapes;
            consumer_shapes.reserve(consumer_args.size());
            for(auto a : consumer_args)
                consumer_shapes.push_back(a == ins ? candidate : a->get_shape());
            return try_compute_shape(consumer, consumer_shapes, consumer->module_inputs());
        });
    }
    catch(const std::exception& e)
    {
        if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
        {
            std::cout << "Exception: " << e.what() << std::endl;
        }
        return false;
    }
    catch(...)
    {
        if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
        {
            std::cout << "Unknown exception" << std::endl;
        }
        return false;
    }
}

110
111
112
/// Convenience overload of try_compute_shape that takes instructions instead
/// of shapes. It uses the current shape of each argument in `args`.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<instruction_ref>& args,
                              const std::vector<module_ref>& mods)
{
    return try_compute_shape(ins, to_shapes(args), mods);
}

118
119
template <class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
120
{
121
122
    auto last = std::prev(m.end());
    std::vector<instruction_ref> const_instructions;
123

124
    for(auto ins : iterator_for(m))
125
    {
126
127
128
129
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;

130
131
132
133
134
135
        if(ins != last and ins->outputs().empty())
            continue;

        if(not f(ins))
            continue;

136
        // Make a copy so we can modify it while we iterate
Shucai Xiao's avatar
Shucai Xiao committed
137
138
139
        auto args     = ins->inputs();
        auto new_args = args;
        auto mod_args = ins->module_inputs();
140

Paul's avatar
Paul committed
141
        for(auto arg : ins->inputs())
142
        {
143
144
            if(arg->name() != op_name)
                continue;
145
146
147
148
149
            if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
            {
                std::cout << "eliminate_contiguous: ";
                m.debug_print(ins);
            }
150
151
152
153
154
155
156
            auto prev = arg->inputs().front();
            replace(new_args, arg, prev);
            if(try_compute_shape(ins, new_args, mod_args))
            {
                instruction::replace_argument(ins, arg, prev);
            }
            else if(prev->can_eval())
157
            {
158
                const_instructions.push_back(arg);
159
160
161
            }
        }
    }
162

Charlie Lin's avatar
Charlie Lin committed
163
    // Perform static contiguous evaluations in parallel
164
165
    std::vector<argument> literals(const_instructions.size());
    par_for(const_instructions.size(), 1, [&](const auto i) {
Charlie Lin's avatar
Charlie Lin committed
166
167
168
169
170
171
172
173
        auto c    = op::contiguous{};
        auto prev = const_instructions[i]->inputs().front();
        // compute the output contiguous shape from the previous instruction shape
        shape computed_shape                   = c.compute_shape({prev->get_shape()});
        const std::vector<argument>& prev_eval = {prev->eval()};
        // prev_eval should not be used in make_compute_output_shape() as computed_shape is static
        auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
        literals[i]   = c.compute(co_shape, prev_eval);
174
175
    });

Charlie Lin's avatar
Charlie Lin committed
176
    // Replace static contiguous operations with a literal
177
    for(size_t i = 0; i < const_instructions.size(); i++)
178
179
    {
        auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
180
        m.replace_instruction(const_instructions[i], l);
181
    }
182
183
}

Paul's avatar
Paul committed
184
185
186
187
188
189
190
191
192
193
194
195
/// Remove `op_name` instructions whose output shape equals their input's
/// shape. Such an instruction changes nothing, so its users are rerouted to
/// its input.
static void remove_contiguous_nops(const std::string& op_name, module& m)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == op_name)
        {
            auto source = ins->inputs().front();
            if(source->get_shape() == ins->get_shape())
                m.replace_instruction(ins, source);
        }
    }
}

196
197
198
199
200
201
202
203
204
/// Run the pass on `m` in three stages:
///   1. bypass contiguous ops, leaving alone slices that share their input with
///      other users (i.e. slices that are one piece of a split);
///   2. bypass contiguous ops for all instructions;
///   3. drop contiguous ops whose output shape already equals their input's.
void eliminate_contiguous::apply(module& m) const
{
    // Skip contiguous from splits first
    auto not_a_split_slice = [](auto ins) {
        return ins->name() != "slice" or ins->inputs().front()->outputs().size() == 1;
    };
    remove_contiguous(op_name, m, not_a_split_slice);

    auto every_instruction = [](auto) { return true; };
    remove_contiguous(op_name, m, every_instruction);

    remove_contiguous_nops(op_name, m);
}

Paul's avatar
Paul committed
208
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
209
} // namespace migraphx