Commit 4d18329c authored by Shucai Xiao
Browse files

add the function to quantize one instruction

parent 83ce18ef
......@@ -19,6 +19,7 @@ add_library(migraphx
instruction.cpp
program.cpp
quantization.cpp
quantize_ins.cpp
shape.cpp
schedule.cpp
simplify_algebra.cpp
......
......@@ -2,6 +2,7 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/fp_conversion.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
......@@ -11,7 +12,8 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref
insert_fp16(program& prog,
instruction_ref& ins,
shape::type_t type std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
if(map_fp16.count(ins) > 0)
{
......@@ -23,7 +25,7 @@ insert_fp16(program& prog,
instruction_ref ins_fp16{};
if(ins == std::prev(prog.end()))
{
ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins);
ins_fp16 = prog.add_instruction(op::fp_conversion{type}, ins);
}
else
{
......@@ -39,26 +41,53 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
{
auto name_it = std::find(ins_name.begin(), ins_name.end(), ins->name());
if(name_it == ins_name.end())
auto name_it = std::find(ins_names.begin(), ins_names.end(), ins->name());
if(name_it == ins_names.end())
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// Process all inputs: if an input is fp32 or fp64, convert it
// to fp16 by adding an fp_conversion operator.
auto inputs = ins->inputs();
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
auto input_fp16 = insert_fp16(prog, input, s.type(), map_fp16);
// If the input is itself an fp_conversion operator, use that
// operator's own input as the current input instead.
instruction_ref input_fp16{};
if (input->name() == "fp_conversion")
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
}
instruction::replace_argument(ins, input, input_fp16, false);
}
}
// recompute the output shape
ins->recompute_ins_shape();
if(ins->get_shape().type() == shape::half_type)
// If output is not the original type, add another instruction
// to convert it back to the original type
if(ins->get_shape().type() != orig_type)
{
instruction_ref ins_orig_type{};
if (ins == std::prev(prog.end()))
{
ins_orig_type = prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
{
ins_orig_type = prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins);
}
prog.replace_instruction(ins, ins_orig_type);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment