"mmdet3d/ops/vscode:/vscode.git/clone" did not exist on "53271e3d8f7a448a2f6fa5bd0bc15621b89ad39c"
quantization.cpp 5.28 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
#include <migraphx/quantization.hpp>
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
5
#include <migraphx/op/convert.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
6
#include <migraphx/op/capture.hpp>
7
#include <migraphx/stringutils.hpp>
8
#include <migraphx/ranges.hpp>
9
10
11
12
13
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
14
15
16
17
/// Inserts a `convert` instruction to `type` immediately after `ins` and
/// returns it.
///
/// Results are memoized in `map_fp16`, keyed by the source instruction, so an
/// instruction that feeds several consumers is only converted once. Note that
/// the cache key does not include `type`: a cached entry is returned as-is.
///
/// @param prog     program that receives the new instruction
/// @param ins      instruction whose result is converted; asserted to be of
///                 float or double type
/// @param type     target type of the inserted convert
/// @param map_fp16 cache from source instruction to its inserted convert
/// @return the cached or newly inserted convert instruction
instruction_ref insert_fp16(program& prog,
                            instruction_ref& ins,
                            shape::type_t type,
                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
    // Reuse the convert created earlier for this instruction, if there is one.
    const auto cached = map_fp16.find(ins);
    if(cached != map_fp16.end())
    {
        return cached->second;
    }

    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);

    // Place the convert right behind its source and remember it for later
    // consumers of the same source.
    const instruction_ref converted =
        prog.insert_instruction(std::next(ins), op::convert{type}, ins);
    map_fp16.emplace(ins, converted);

    return converted;
}

33
/// Rewrites the selected instructions of `prog` to consume fp16 inputs.
///
/// For every instruction whose name is listed in `ins_names` (or for every
/// instruction when the list contains "all"), each float/double input is
/// replaced by a `convert` to half. If that changes the instruction's output
/// type, a second `convert` back to the original type is placed after it so
/// that downstream users keep seeing the original type.
///
/// @param prog      program to modify in place
/// @param ins_names operator names to quantize; "all" selects everything
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
    // The name list never changes during the walk, so decide once whether
    // every instruction is selected instead of re-scanning per instruction.
    const bool convert_all = contains(ins_names, "all");

    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
    for(auto ins : iterator_for(prog))
    {
        if((not convert_all) and (not contains(ins_names, ins->name())))
        {
            continue;
        }

        shape::type_t orig_type = ins->get_shape().type();
        // process all inputs, if input is a fp32 or fp64, convert it
        // to a fp16 by adding a convert operator.
        auto inputs = ins->inputs();
        std::vector<instruction_ref> converted_inputs;
        converted_inputs.reserve(inputs.size());
        for(auto input : inputs)
        {
            const auto input_type = input->get_shape().type();
            if(input_type != shape::float_type and input_type != shape::double_type)
            {
                // Not a floating point value we quantize: pass it through.
                converted_inputs.push_back(input);
                continue;
            }

            // A convert whose source is already fp16 (typically the
            // convert-back inserted below for an earlier instruction) can be
            // bypassed by using its source directly. A convert from any other
            // type must NOT be bypassed: its source is not fp16, so it needs a
            // real fp16 convert like every other float/double input.
            const bool converts_from_half =
                input->name() == "convert" and
                input->inputs().front()->get_shape().type() == shape::half_type;

            instruction_ref input_fp16{};
            if(converts_from_half)
            {
                input_fp16 = input->inputs().front();
            }
            else
            {
                input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
            }
            converted_inputs.push_back(input_fp16);
        }

        // no change for the input, go to the next instruction
        if(inputs == converted_inputs)
        {
            continue;
        }

        auto op        = ins->get_operator();
        auto ins_shape = compute_shape(op, converted_inputs);
        if(ins_shape.type() != orig_type)
        {
            // insert another convert instruction to convert it back
            if(ins == std::prev(prog.end()))
            {
                prog.add_instruction(op::convert{orig_type}, ins);
            }
            else
            {
                // check the dead code case to avoid assert; this has to be
                // sampled before the convert below is inserted, because that
                // convert takes `ins` as its argument.
                bool output_empty = ins->outputs().empty();
                auto ins_orig_type =
                    prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
                if(!output_empty)
                {
                    prog.replace_instruction(ins, ins_orig_type);
                }
            }
        }

        prog.replace_instruction(ins, op, converted_inputs);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
105
/// Convenience overload: quantizes every instruction of `prog` by delegating
/// to the name-filtered overload with the special name "all".
void quantize(program& prog)
{
    const std::vector<std::string> every_instruction{"all"};
    quantize(prog, every_instruction);
}
Shucai Xiao's avatar
Shucai Xiao committed
106

Shucai Xiao's avatar
Shucai Xiao committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
std::vector<std::vector<argument> > ins_args;
void capture_args(std::size_t ins_index, std::vector<argument> args) {
    if (ins_index = ins_args.size())
    {
        ins_args.push_back(std::vector<argument>{});
    }
    ins_args[ins_index].push_back(args.front());

    return;
}

/// Placeholder for computing quantization parameters from captured arguments.
///
/// Not implemented yet: neither `ins_arg` nor `ins_params` is read or
/// modified.
void calc_quant_params(std::vector<std::vector<argument>>& ins_arg,
                       std::vector<std::pair<float, float>>& ins_params)
{
    (void)ins_arg;
    (void)ins_params;
}

// For every input of each selected instruction, insert a capture operator
// behind that input (used to compute the scale and shift) and rewire the
// instruction to read from the capture instead.
//
// ins_names may only contain "dot" and "convolution"; anything else throws.
// An input shared by several selected instructions gets a single capture.
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
    // the int8 quantization only support dot and convolution
    const std::vector<std::string> op_names = {"dot", "convolution"};
    const bool all_supported =
        std::all_of(ins_names.begin(), ins_names.end(), [&](const auto& name) {
            return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
        });
    if(not all_supported)
    {
        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
    }

    // Maps an original input to the capture instruction inserted after it.
    std::unordered_map<instruction_ref, instruction_ref> ins_map;
    // Running index handed to each newly created capture operator.
    std::size_t index = 0;
    for(auto ins : iterator_for(prog))
    {
        if(not contains(ins_names, ins->name()))
        {
            continue;
        }

        auto inputs = ins->inputs();
        std::vector<instruction_ref> new_args;
        for(auto input : inputs)
        {
            auto found = ins_map.find(input);
            if(found == ins_map.end())
            {
                // First use of this input: create its capture and remember it.
                auto capture_ins = prog.insert_instruction(
                    std::next(input), op::capture{index++, capture_args}, input);
                found = ins_map.emplace(input, capture_ins).first;
            }
            new_args.push_back(found->second);
        }
        // Same operator and shape as before, now fed by the capture outputs.
        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
    }
}

165
166
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx