quantization.cpp 3.96 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
#include <migraphx/quantization.hpp>
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
5
#include <migraphx/op/convert.hpp>
6
#include <migraphx/stringutils.hpp>
7
#include <migraphx/ranges.hpp>
8
9
10
11
12
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
13
14
15
16
// Return an instruction producing the value of `ins` converted to `type`,
// creating it on the first request and reusing it on later requests.
//
// prog     - program the new literal / convert instruction is added to
// ins      - instruction whose result is converted; asserted to be of
//            float or double type
// type     - target element type of the conversion
// map_fp16 - cache from an original instruction to its converted
//            counterpart; a new entry is recorded for `ins`
//
// Returns the cached or newly created converted instruction.
instruction_ref insert_fp16(program& prog,
                            instruction_ref& ins,
                            shape::type_t type,
                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
    // Single lookup: reuse a conversion created earlier for the same instruction.
    auto cached = map_fp16.find(ins);
    if(cached != map_fp16.end())
    {
        return cached->second;
    }

    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);
    instruction_ref ins_fp16{};
    if(ins->name() == "@literal" && ins->outputs().size() == 1)
    {
        // A literal with exactly one output is re-created as a new literal of
        // the requested type instead of going through a convert operator.
        // The original literal is not removed by this function.
        std::vector<float> values;
        auto l_fp32 = ins->get_literal();
        shape s     = ins->get_shape();
        l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
        // Honor the requested target type rather than hard-coding half_type.
        ins_fp16 = prog.add_literal(literal({type, s.lens()}, values));
    }
    else
    {
        if(ins == std::prev(prog.end()))
        {
            // `ins` is the last instruction: append the convert after it.
            ins_fp16 = prog.add_instruction(op::convert{type}, ins);
        }
        else
        {
            // Place the convert right after `ins`; the target type is passed
            // explicitly so both branches convert to the same requested type.
            ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
        }
    }
    map_fp16[ins] = ins_fp16;

    return ins_fp16;
}

50
// Rewrite selected instructions of `prog` so that their float/double inputs
// are fed as fp16 values.
//
// prog      - program to modify in place
// ins_names - names of the instructions to process; the special name "all"
//             selects every instruction
//
// For each selected instruction, every float/double input is replaced by an
// fp16 counterpart (see insert_fp16). If the recomputed output type differs
// from the original one, a convert back to the original type is inserted
// after the instruction.
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
    for(auto ins : iterator_for(prog))
    {
        // all indicates every instruction is converted
        if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
        {
            continue;
        }

        shape::type_t orig_type = ins->get_shape().type();
        // process all inputs, if input is a fp32 or fp64, convert it
        // to a fp16 by adding a convert operator.
        auto inputs = ins->inputs();
        std::vector<instruction_ref> converted_inputs;
        converted_inputs.reserve(inputs.size());
        for(auto input : inputs)
        {
            auto s = input->get_shape();
            if(s.type() == shape::float_type || s.type() == shape::double_type)
            {
                // if the input is a convert operator, uses its input
                // as its current input
                // NOTE(review): the type of the convert's own input is not
                // checked here; presumably it is fp16 - confirm.
                instruction_ref input_fp16{};
                if(input->name() == "convert")
                {
                    input_fp16 = input->inputs().front();
                }
                else
                {
                    input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
                }
                converted_inputs.push_back(input_fp16);
            }
            else
            {
                converted_inputs.push_back(input);
            }
        }

        // no change for the input, go on to the next instruction. Returning
        // here would leave every later instruction unprocessed.
        if(inputs == converted_inputs)
        {
            continue;
        }

        auto op        = ins->get_operator();
        auto ins_shape = compute_shape(op, converted_inputs);
        if(ins_shape.type() != orig_type)
        {
            // insert another convert instruction to convert it back
            if(ins == std::prev(prog.end()))
            {
                prog.add_instruction(op::convert{orig_type}, ins);
            }
            else
            {
                // check the dead code case to avoid assert
                // (the flag is sampled before the convert is inserted, since
                // the convert itself takes ins as an argument)
                bool output_empty = ins->outputs().empty();
                auto ins_orig_type =
                    prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
                if(!output_empty)
                {
                    prog.replace_instruction(ins, ins_orig_type);
                }
            }
        }

        // Finally rebuild ins itself with the converted inputs.
        prog.replace_instruction(ins, op, converted_inputs);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx