"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2fdf510d05a11280fff4688aa231491be98ef8d6"
quantization.cpp 5.31 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
#include <migraphx/quantization.hpp>
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
5
#include <migraphx/op/convert.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
6
#include <migraphx/op/capture.hpp>
7
#include <migraphx/stringutils.hpp>
8
#include <migraphx/ranges.hpp>
9
10
11
12
13
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
14
15
16
17
// Returns the instruction holding `ins` converted to `type`.
//
// On the first request for a given `ins`, an op::convert{type} taking `ins`
// as its argument is inserted at the position right after `ins`, and the
// result is remembered in `map_fp16`. Later requests for the same `ins`
// return the remembered instruction instead of inserting another convert.
//
// prog     - program that receives the new convert instruction
// ins      - instruction whose output is to be converted
// type     - target type handed to op::convert
// map_fp16 - cache from source instruction to its convert instruction
instruction_ref insert_fp16(program& prog,
                            instruction_ref& ins,
                            shape::type_t type,
                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
    auto cached = map_fp16.find(ins);
    if(cached != map_fp16.end())
    {
        return cached->second;
    }

    // Only fp32/fp64 sources are expected here (checked when asserts are enabled).
    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);

    instruction_ref converted = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
    map_fp16.emplace(ins, converted);
    return converted;
}

33
// Rewrites the selected instructions of `prog` so that their fp32/fp64 inputs
// are fed as fp16.
//
// prog      - program to modify in place
// ins_names - names of the operators to rewrite; the special entry "all"
//             selects every instruction
//
// For each selected instruction, every float/double input is replaced either
// by the fp16 operand of an existing convert (when that operand is already
// half_type) or by a convert-to-half obtained from insert_fp16. If the
// operator's output type computed from the new inputs differs from the
// original type, a convert back to the original type is added so that
// consumers keep seeing the original type.
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
    // "all" indicates every instruction is converted; the lookup does not
    // depend on the instruction, so do it once.
    const bool convert_all = contains(ins_names, "all");

    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
    for(auto ins : iterator_for(prog))
    {
        if((not convert_all) and (not contains(ins_names, ins->name())))
        {
            continue;
        }

        shape::type_t orig_type = ins->get_shape().type();
        // process all inputs, if input is a fp32 or fp64, convert it
        // to a fp16 by adding a convert operator.
        auto inputs = ins->inputs();
        std::vector<instruction_ref> converted_inputs;
        converted_inputs.reserve(inputs.size());
        for(auto input : inputs)
        {
            const auto& s = input->get_shape();
            if(s.type() == shape::float_type || s.type() == shape::double_type)
            {
                // If the input is a convert whose operand is already fp16,
                // use that operand directly instead of converting back and
                // forth. A convert from any other type (e.g. int32 -> float
                // or float -> double) must NOT be bypassed: its operand is
                // not fp16, so it still needs an explicit convert to half.
                instruction_ref input_fp16{};
                if(input->name() == "convert" and
                   input->inputs().front()->get_shape().type() == shape::half_type)
                {
                    input_fp16 = input->inputs().front();
                }
                else
                {
                    input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
                }
                converted_inputs.push_back(input_fp16);
            }
            else
            {
                converted_inputs.push_back(input);
            }
        }

        // no change for the input, go to the next instruction
        if(inputs == converted_inputs)
        {
            continue;
        }

        auto op        = ins->get_operator();
        auto ins_shape = compute_shape(op, converted_inputs);
        if(ins_shape.type() != orig_type)
        {
            // insert another convert instruction to convert it back
            if(ins == std::prev(prog.end()))
            {
                // last instruction of the program: append the convert at the end
                prog.add_instruction(op::convert{orig_type}, ins);
            }
            else
            {
                // check the dead code case to avoid assert
                // (the emptiness of outputs is sampled before the convert
                // below becomes an output of ins)
                bool output_empty = ins->outputs().empty();
                auto ins_orig_type =
                    prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
                if(!output_empty)
                {
                    prog.replace_instruction(ins, ins_orig_type);
                }
            }
        }

        // Finally switch the instruction itself over to the converted inputs.
        prog.replace_instruction(ins, op, converted_inputs);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
105
// Convenience overload: runs quantize on `prog` with the single selector
// "all", which the two-argument overload treats as "every instruction".
void quantize(program& prog)
{
    const std::vector<std::string> every_instruction{"all"};
    quantize(prog, every_instruction);
}
Shucai Xiao's avatar
Shucai Xiao committed
106

Shucai Xiao's avatar
Shucai Xiao committed
107
108
109
std::vector<std::vector<argument>> ins_args;
void capture_args(std::size_t ins_index, std::vector<argument> args)
{
Shucai Xiao's avatar
Shucai Xiao committed
110
    if(ins_index == ins_args.size())
Shucai Xiao's avatar
Shucai Xiao committed
111
112
113
114
115
116
117
118
    {
        ins_args.push_back(std::vector<argument>{});
    }
    ins_args[ins_index].push_back(args.front());

    return;
}

Shucai Xiao's avatar
Shucai Xiao committed
119
120
// Placeholder for computing quantization parameters from captured arguments.
// The current implementation performs no work: neither `ins_arg` nor
// `ins_params` is read or modified.
void calc_quant_params(std::vector<std::vector<argument>>& ins_arg,
                       std::vector<std::pair<float, float>>& ins_params)
{
    // Intentionally empty; mark the parameters as deliberately unused.
    (void)ins_arg;
    (void)ins_params;
}

// Inserts an op::capture in front of every input of the selected
// instructions; per the original note, the captures exist so that scale and
// shift can be computed for those inputs.
//
// prog      - program to modify in place
// ins_names - operator names whose inputs are captured; every entry must be
//             "dot" or "convolution", otherwise MIGRAPHX_THROW is invoked
//
// Each distinct input instruction gets exactly one capture (with its own
// running index and the capture_args callback), shared by all selected
// instructions that consume it.
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
    // the int8 quantization only support dot and convolution
    const std::vector<std::string> supported_ops = {"dot", "convolution"};
    const auto is_supported                      = [&](const auto& name) {
        return std::find(supported_ops.begin(), supported_ops.end(), name) != supported_ops.end();
    };
    if(!std::all_of(ins_names.begin(), ins_names.end(), is_supported))
    {
        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
    }

    // input instruction -> capture instruction already inserted after it
    std::unordered_map<instruction_ref, instruction_ref> captured;
    std::size_t next_index = 0;
    for(auto ins : iterator_for(prog))
    {
        if(not contains(ins_names, ins->name()))
        {
            continue;
        }

        auto inputs = ins->inputs();
        std::vector<instruction_ref> new_args;
        for(auto input : inputs)
        {
            auto found = captured.find(input);
            if(found == captured.end())
            {
                // First time this input is seen: place a capture right after
                // it and remember the pairing.
                auto capture_ins = prog.insert_instruction(
                    std::next(input), op::capture{next_index++, capture_args}, input);
                found = captured.emplace(input, capture_ins).first;
            }
            new_args.push_back(found->second);
        }

        // Same operator and shape as before, now reading from the captures.
        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
    }
}

168
169
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx