quantize_ins.cpp 3.05 KB
Newer Older
1
2
3
4
#include <migraphx/quantize_ins.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
5
#include <migraphx/op/fp_conversion.hpp>
6
7
8
9
10
11
#include <migraphx/stringutils.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
12
13
14
15
/// Return the instruction that converts the output of `ins` to `type`,
/// creating it on first request and caching it in `map_fp16`.
///
/// @param prog      program that owns `ins`; the conversion is added to it.
/// @param ins       instruction whose output is converted. Its shape type is
///                  asserted to be float or double.
/// @param type      target type handed to the `op::fp_conversion` operator.
/// @param map_fp16  cache from an instruction to the conversion already
///                  created for it, so each instruction is converted once.
/// @return the cached or newly created conversion instruction.
instruction_ref insert_fp16(program& prog,
                            instruction_ref& ins,
                            shape::type_t type,
                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
    // Reuse the conversion created earlier for this instruction, if any.
    auto cached = map_fp16.find(ins);
    if(cached != map_fp16.end())
    {
        return cached->second;
    }

    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);

    instruction_ref ins_fp16{};
    if(ins == std::prev(prog.end()))
    {
        // `ins` is the last instruction: append the conversion.
        ins_fp16 = prog.add_instruction(op::fp_conversion{type}, ins);
    }
    else
    {
        // Place the conversion right after `ins`. The requested `type` is
        // passed here too, so the result does not depend on where `ins`
        // sits in the program.
        ins_fp16 = prog.insert_instruction(std::next(ins), op::fp_conversion{type}, ins);
    }
    map_fp16[ins] = ins_fp16;

    return ins_fp16;
}

38
39
/// Rewrite every instruction of `prog` whose name appears in `ins_names` so
/// that its float/double inputs are fed through a half-type conversion.
///
/// For each matching instruction:
///  - every input whose shape type is float or double is replaced, either by
///    the first input of that input (when the input is itself an
///    "fp_conversion") or by the conversion returned from insert_fp16;
///  - the instruction's output shape is recomputed;
///  - if the output type then differs from the type it had before, an
///    `op::fp_conversion` back to that earlier type is added directly after
///    the instruction and passed to `prog.replace_instruction`.
///
/// @param prog       program to modify in place.
/// @param ins_names  names of the instructions to process.
void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
{
    // Conversions created so far, keyed by the instruction they convert.
    std::unordered_map<instruction_ref, instruction_ref> converted;
    for(auto ins : iterator_for(prog))
    {
        // Leave instructions that were not requested untouched.
        if(std::find(ins_names.begin(), ins_names.end(), ins->name()) == ins_names.end())
        {
            continue;
        }

        const shape::type_t type_before = ins->get_shape().type();

        // Iterate over a copy of the argument list, since the arguments of
        // `ins` are replaced inside the loop.
        const auto args = ins->inputs();
        for(auto arg : args)
        {
            const auto arg_type = arg->get_shape().type();
            if(arg_type != shape::float_type && arg_type != shape::double_type)
            {
                continue;
            }

            instruction_ref half_arg{};
            if(arg->name() != "fp_conversion")
            {
                half_arg = insert_fp16(prog, arg, shape::half_type, converted);
            }
            else
            {
                // The argument is already a conversion: use what it converts
                // from instead of stacking another conversion on top.
                half_arg = arg->inputs().front();
            }
            instruction::replace_argument(ins, arg, half_arg, false);
        }

        // The inputs changed, so the output shape has to be refreshed.
        ins->recompute_ins_shape();

        // Nothing more to do when the output type is unchanged.
        if(ins->get_shape().type() == type_before)
        {
            continue;
        }

        // Convert the result back to the type the instruction had before.
        instruction_ref restored{};
        if(ins == std::prev(prog.end()))
        {
            restored = prog.add_instruction(op::fp_conversion{type_before}, ins);
        }
        else
        {
            restored =
                prog.insert_instruction(std::next(ins), op::fp_conversion{type_before}, ins);
        }

        prog.replace_instruction(ins, restored);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx