quantization.cpp 2.73 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
5
#include <migraphx/op/fp_conversion.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Create the reduced-precision counterpart of a floating-point literal or
/// parameter instruction.
///
/// - "@literal": the literal's elements are copied into a float buffer and a
///   new literal with the same dimensions but shape::half_type is added to
///   the program. The original literal is not removed here.
/// - "@param": an op::fp_conversion instruction taking the parameter as its
///   argument is placed directly after the parameter.
///
/// @param prog program that receives the new instruction
/// @param ins  instruction to convert; asserted to be an "@literal" or
///             "@param" whose shape type is float_type or double_type
/// @return reference to the newly created instruction. If ins is neither an
///         "@literal" nor an "@param" (only reachable when asserts are
///         compiled out) a default-constructed instruction_ref is returned.
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{
    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);
    assert(contains({"@literal", "@param"}, ins->name()));

    instruction_ref ins_fp16{};
    if(ins->name() == "@literal")
    {
        // Stage the values as float regardless of the source type, then let
        // the literal constructor store them with the half_type shape.
        std::vector<float> values;
        auto l_fp32 = ins->get_literal();
        shape s     = ins->get_shape();
        l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
        ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values));
    }
    else if(ins->name() == "@param")
    {
        // NOTE(review): a default-constructed op::fp_conversion presumably
        // targets half precision -- confirm against fp_conversion.hpp.
        if(ins == std::prev(prog.end()))
        {
            // The parameter is the last instruction: append the conversion.
            ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins);
        }
        else
        {
            // Otherwise place the conversion immediately after the parameter.
            ins_fp16 = prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
        }
    }

    return ins_fp16;
}

/// Rewrite a program so that its floating-point literals and parameters feed
/// their consumers through reduced-precision instructions.
///
/// First pass: for every "@literal" / "@param" whose shape type is float_type
/// or double_type, a replacement is created with convert_fp32_fp16 and each
/// consumer of the original instruction is re-pointed at that replacement.
///
/// Second pass (only if something was replaced): the shape of every
/// instruction that is not an "@literal" or "@param" is recomputed, and if the
/// final instruction then produces half_type, an op::fp_conversion back to the
/// recorded original type is appended.
///
/// @param prog program to modify in place
void quantize(program& prog)
{
    bool reduced_precision  = false;
    // Type of the most recently converted literal/parameter; used as the
    // target of the trailing conversion. Stays float_type if none is found.
    shape::type_t orig_type = shape::float_type;
    for(auto ins : iterator_for(prog))
    {
        // convert float_type to half_type
        if(!contains({"@literal", "@param"}, ins->name()))
            continue;

        const shape::type_t type = ins->get_shape().type();
        if(type != shape::float_type && type != shape::double_type)
            continue;

        orig_type     = type;
        auto ins_fp16 = convert_fp32_fp16(prog, ins);

        // Work on a copy of the consumer list: replace_argument is called on
        // those consumers while we walk it.
        auto outputs = ins->outputs();
        for(auto output : outputs)
        {
            // For a parameter the new conversion itself consumes ins and
            // must keep doing so.
            if(output != ins_fp16)
            {
                instruction::replace_argument(output, ins, ins_fp16, false);
            }
        }

        reduced_precision = true;
    }

    // add another instruction at last to convert fp16 to fp32
    if(reduced_precision)
    {
        // Arguments were swapped above with the trailing flag set to false,
        // so shapes are refreshed here in program order.
        for(auto ins : iterator_for(prog))
        {
            if(!contains({"@literal", "@param"}, ins->name()))
            {
                ins->recompute_ins_shape();
            }
        }

        auto ins = std::prev(prog.end());
        if(ins->get_shape().type() == shape::half_type)
        {
            prog.add_instruction(op::fp_conversion{orig_type}, ins);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx