propagate_constant.cpp 5.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
#include <migraphx/propagate_constant.hpp>
Paul's avatar
Paul committed
25
26
27
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
28
#include <migraphx/functional.hpp>
29
#include <migraphx/par_for.hpp>
30
#include <migraphx/env.hpp>
31
#include <unordered_set>
charlie's avatar
charlie committed
32
#include <migraphx/make_op.hpp>
Paul's avatar
Paul committed
33

Paul's avatar
Paul committed
34
namespace migraphx {
Paul's avatar
Paul committed
35
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
36

37
38
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)

Paul's avatar
Paul committed
39
// Decide whether an otherwise-evaluable instruction should NOT be folded into a literal.
// Any chain of "contiguous" instructions is looked through, and the decision is made on the
// shape of the instruction that feeds it:
//   - a scalar-strided shape is skipped only when it stands for more than one element;
//   - any other shape is skipped when it is broadcasted.
// NOTE(review): presumably this avoids materialising large literals out of cheap broadcasts —
// confirm against the pass's design notes.
bool skip_propogate(instruction_ref ins)
{
    while(ins->name() == "contiguous")
        ins = ins->inputs().front();

    const auto& shape_of_ins = ins->get_shape();
    if(shape_of_ins.scalar())
        return shape_of_ins.elements() != 1;
    return shape_of_ins.broadcasted();
}
Paul's avatar
Paul committed
50

51
// An instruction is treated as a constant (foldable) when it can be evaluated at compile time
// and skip_propogate() does not exclude it.
bool is_const_ins(instruction_ref ins)
{
    if(not ins->can_eval())
        return false;
    return not skip_propogate(ins);
}
52

53
void propagate_constant::apply(module& m) const
Paul's avatar
Paul committed
54
{
55
56
57
58
    std::unordered_set<instruction_ref> const_instrs;
    auto last = std::prev(m.end());

    // Find instructions that can be evaluated to a literal
59
    for(auto i : iterator_for(m))
60
    {
61
62
        const bool is_const = is_const_ins(i);
        if(is_const and i != last)
63
            continue;
64

65
66
67
68
69
70
71
72
73
74
75
76
77
        if(i == last and is_const)
        {
            const_instrs.insert(i);
        }
        else
        {
            std::copy_if(i->inputs().begin(),
                         i->inputs().end(),
                         std::inserter(const_instrs, const_instrs.begin()),
                         [&](const instruction_ref ins) {
                             return is_const_ins(ins) and ins->name() != "@literal";
                         });
        }
78
79
80
81
82
    }

    // Compute literals in parallel
    std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
    std::vector<argument> literals(const_instrs_vec.size());
charlie's avatar
charlie committed
83

charlie's avatar
charlie committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
    // DEBUG
    // for(int i = 0; i < const_instrs_vec.size(); ++i)
    //{
    //     auto ins = const_instrs_vec[i];
    //     if(ins->get_shape().type() == shape::half_type)
    //     {
    //         auto inputs = ins->inputs();
    //         std::vector<instruction_ref> new_inputs(inputs.size());
    //         std::vector<instruction_ref> added_instructions;
    //         std::transform(inputs.begin(), inputs.end(), new_inputs.begin(), [&](auto input) {
    //             auto input_type = input->get_shape().type();
    //             if(input_type != shape::half_type and input_type != shape::float_type)
    //                 return input;
    //             auto ai = m.add_instruction(
    //                 make_op("convert", {{"target_type", shape::double_type}}), input);
    //             added_instructions.push_back(ai);
    //             return ai;
    //         });
    //         auto new_ins = m.add_instruction(ins->get_operator(), new_inputs);
    //         added_instructions.push_back(new_ins);
    //         auto after_convert = m.add_instruction(
    //             make_op("convert", {{"target_type", ins->get_shape().type()}}), new_ins);
    //         added_instructions.push_back(after_convert);
    //         literals[i] = after_convert->eval();
    //         for(auto a_ins : added_instructions)
    //         {
    //             m.remove_instruction(a_ins);
    //         }
    //     }
    //     else
    //     {
    //         literals[i] = const_instrs_vec[i]->eval();
    //     }
    // }

    // Original
    par_for(const_instrs_vec.size(), 1, [&](const auto i) {
        literals[i] = const_instrs_vec[i]->eval();
    });
123
124
125
126
127
128

    // Replace instructions in m
    for(size_t i = 0; i < const_instrs_vec.size(); i++)
    {
        if(not literals[i].empty())
        {
129
130
131
132
133
134
135
136
137
138
139
140
141
            if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
            {
                std::cout << "Constant replace: " << std::endl;
                std::vector<instruction_ref> inss;
                fix([&](auto self, auto ins) {
                    if(contains(inss, ins))
                        return;
                    for(auto input : ins->inputs())
                        self(input);
                    inss.push_back(ins);
                })(const_instrs_vec[i]);
                m.debug_print(inss);
            }
142
143
144
145
            assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
            m.replace_instruction(const_instrs_vec[i], l);
        }
146
    }
Paul's avatar
Paul committed
147
}
Paul's avatar
Paul committed
148

Paul's avatar
Paul committed
149
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
150
} // namespace migraphx