// quantization.cpp
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/fp_conversion.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <sstream>
#include <utility>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Create a reduced-precision counterpart of a builtin instruction.
///
/// Two builtins are handled:
///  - "@literal": the literal's elements are copied into a std::vector<float>
///    and a new literal of shape::half_type with the same lens is added to
///    the program.
///  - "@param": a default-constructed op::fp_conversion taking `ins` as its
///    input is placed right after `ins` (appended when `ins` is the last
///    instruction of the program).
///    NOTE(review): the default target type of op::fp_conversion is not
///    visible here; presumably half_type -- confirm in op/fp_conversion.hpp.
///
/// @param prog Program that contains `ins`; the new instruction is added to it.
/// @param ins  Builtin instruction (name starts with '@') whose shape type is
///             float_type or double_type. Both conditions are only asserted.
/// @return Reference to the newly created instruction. For any builtin other
///         than "@literal" and "@param" nothing is created and a
///         default-constructed instruction_ref is returned, which the caller
///         must not use.
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{
    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);
    assert(ins->name().front() == '@');
    instruction_ref ins_fp16{};
    if(ins->name() == "@literal")
    {
        // Read the elements back as float, then rebuild the literal with a
        // half_type shape of identical dimensions.
        std::vector<float> values;
        auto l_fp32 = ins->get_literal();
        shape s     = ins->get_shape();
        l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
        ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values));
    }
    else if(ins->name() == "@param")
    {
        // The conversion must directly follow the parameter: append when the
        // parameter is the last instruction, otherwise insert before its
        // successor.
        if(ins == std::prev(prog.end()))
        {
            ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins);
        }
        else
        {
            ins_fp16 = prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
        }
    }

    return ins_fp16;
}

/// Rewrite `prog` so that float_type literals and parameters feed their users
/// through half-precision instructions.
///
/// Pass 1: every "@literal" / "@param" instruction of float_type gets a
/// counterpart from convert_fp32_fp16, and each user of the original
/// instruction (except the counterpart itself) is re-pointed to it.
/// Pass 2 (only if something was converted): the shape of every non-builtin
/// instruction is recomputed, and if the last instruction of the program now
/// has half_type, an op::fp_conversion{shape::float_type} is appended after it.
///
/// @param prog Program to modify in place.
void quantize(program& prog)
{
    bool reduced_precision = false;
    for(auto ins : iterator_for(prog))
    {
        // convert_fp32_fp16 only knows how to convert these two builtins; for
        // any other '@' instruction it returns a default-constructed
        // instruction_ref, which must not be compared or wired in as an
        // argument. Skip everything else.
        const auto name = ins->name();
        if(name != "@literal" && name != "@param")
        {
            continue;
        }

        // convert float_type to half_type
        if(ins->get_shape().type() != shape::float_type)
        {
            continue;
        }

        auto ins_fp16 = convert_fp32_fp16(prog, ins);
        // Work on a copy of the user list taken before any rewiring.
        // NOTE(review): presumably replace_argument edits ins's output list -- confirm.
        auto outputs = ins->outputs();
        for(auto output : outputs)
        {
            // For "@param" the conversion instruction is itself a user of
            // ins and must keep ins as its input.
            if(output != ins_fp16)
            {
                instruction::replace_argument(output, ins, ins_fp16, false);
            }
        }

        reduced_precision = true;
    }

    // add another instruction at last to convert fp16 to fp32
    if(reduced_precision)
    {
        // Inputs changed type, so refresh the shapes of all non-builtin
        // instructions.
        for(auto ins : iterator_for(prog))
        {
            if(ins->name().front() != '@')
            {
                ins->recompute_ins_shape();
            }
        }

        auto ins = std::prev(prog.end());
        if(ins->get_shape().type() == shape::half_type)
        {
            prog.add_instruction(op::fp_conversion{shape::float_type}, ins);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx