quantization.cpp 2.71 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
5
#include <migraphx/op/fp_conversion.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Create the reduced-precision counterpart of a program input.
//
// `ins` is expected to be an "@literal" or "@param" instruction whose shape
// type is float or double; both expectations are checked with assert only, so
// they vanish when NDEBUG is defined.
//
//  - "@literal": the element values are copied into a std::vector<float> and
//    a new literal with shape::half_type and the same lens is added to `prog`.
//  - "@param":   an op::fp_conversion{} taking `ins` as its argument is placed
//    directly after `ins` (appended when `ins` is the last instruction).
//    NOTE(review): the default-constructed op::fp_conversion presumably
//    targets half precision -- confirm against the operator's definition.
//
// Returns the newly created instruction, or a value-initialized
// instruction_ref when `ins` is neither of the two kinds.
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{
    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);
    assert(contains({"@literal", "@param"}, ins->name()));

    const auto& kind = ins->name();

    if(kind == "@literal")
    {
        auto source     = ins->get_literal();
        const auto lens = ins->get_shape().lens();

        // Stage the data as float; the half-typed literal is built from it.
        std::vector<float> as_float;
        source.visit([&](auto view) { as_float.assign(view.begin(), view.end()); });

        return prog.add_literal(literal({shape::half_type, lens}, as_float));
    }

    if(kind == "@param")
    {
        // A trailing parameter has no successor to insert in front of, so the
        // conversion is appended instead.
        const bool is_last = (ins == std::prev(prog.end()));
        if(is_last)
        {
            return prog.add_instruction(op::fp_conversion{}, ins);
        }
        return prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
    }

    return instruction_ref{};
}

// Rewrite `prog` so that its float/double inputs are consumed in reduced
// precision.
//
//  1. Every "@literal"/"@param" instruction whose type is float or double gets
//     a counterpart from convert_fp32_fp16(), and each of its consumers is
//     re-pointed at that counterpart.
//  2. If anything was converted, recompute_ins_shape() is called on every
//     instruction that is not a literal or parameter (presumably so the new
//     input types propagate through the program).
//  3. If the last instruction now yields half_type although the program
//     originally produced some other type, a trailing op::fp_conversion
//     restores that original output type.
//
// The output type is captured from the last instruction *before* any
// rewriting. Deriving it from whichever input happened to be converted last
// would restore the wrong type for programs mixing float and double inputs,
// and would change the output type of a program that already produced half.
void quantize(program& prog)
{
    // An empty program has nothing to convert; bailing out also keeps the
    // std::prev(prog.end()) calls below well-defined.
    if(prog.begin() == prog.end())
    {
        return;
    }

    // Type the caller observes today; this is what must be restored later.
    const shape::type_t output_type = std::prev(prog.end())->get_shape().type();

    bool reduced_precision = false;
    for(auto ins : iterator_for(prog))
    {
        const bool is_input = contains({"@literal", "@param"}, ins->name());
        if(!is_input)
        {
            continue;
        }

        const auto in_type = ins->get_shape().type();
        if(in_type != shape::float_type && in_type != shape::double_type)
        {
            continue;
        }

        // convert float_type/double_type to half_type
        auto ins_fp16 = convert_fp32_fp16(prog, ins);

        // Work on a copy of the consumer list: replace_argument presumably
        // edits ins's own output list while we walk it.
        auto outputs = ins->outputs();
        for(auto output : outputs)
        {
            // For a parameter, the freshly inserted conversion is itself a
            // consumer of `ins` and must keep reading the original value.
            if(output != ins_fp16)
            {
                instruction::replace_argument(output, ins, ins_fp16, false);
            }
        }

        reduced_precision = true;
    }

    if(!reduced_precision)
    {
        return;
    }

    for(auto ins : iterator_for(prog))
    {
        if(!contains({"@literal", "@param"}, ins->name()))
        {
            ins->recompute_ins_shape();
        }
    }

    // Convert the result back only when quantization actually changed the
    // program's output type.
    auto last = std::prev(prog.end());
    if(last->get_shape().type() == shape::half_type && output_type != shape::half_type)
    {
        prog.add_instruction(op::fp_conversion{output_type}, last);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx