Commit 4d18329c authored by Shucai Xiao
Browse files

add the function to quantize one instruction

parent 83ce18ef
...@@ -19,6 +19,7 @@ add_library(migraphx ...@@ -19,6 +19,7 @@ add_library(migraphx
instruction.cpp instruction.cpp
program.cpp program.cpp
quantization.cpp quantization.cpp
quantize_ins.cpp
shape.cpp shape.cpp
schedule.cpp schedule.cpp
simplify_algebra.cpp simplify_algebra.cpp
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/op/fp_conversion.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <utility> #include <utility>
...@@ -11,7 +12,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +12,8 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref instruction_ref
insert_fp16(program& prog, insert_fp16(program& prog,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type std::unordered_map<instruction_ref, instruction_ref>& map_fp16) shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{ {
if(map_fp16.count(ins) > 0) if(map_fp16.count(ins) > 0)
{ {
...@@ -23,7 +25,7 @@ insert_fp16(program& prog, ...@@ -23,7 +25,7 @@ insert_fp16(program& prog,
instruction_ref ins_fp16{}; instruction_ref ins_fp16{};
if(ins == std::prev(prog.end())) if(ins == std::prev(prog.end()))
{ {
ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins); ins_fp16 = prog.add_instruction(op::fp_conversion{type}, ins);
} }
else else
{ {
...@@ -39,26 +41,53 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -39,26 +41,53 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
std::unordered_map<instruction_ref, instruction_ref> map_fp16; std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
auto name_it = std::find(ins_name.begin(), ins_name.end(), ins->name()); auto name_it = std::find(ins_names.begin(), ins_names.end(), ins->name());
if(name_it == ins_name.end()) if(name_it == ins_names.end())
{ {
continue; continue;
} }
shape::type_t orig_type = ins->get_shape().type();
// process all inputs, if input is a fp32 or fp64, convert it
// to a fp16 by adding a fp_conversion operator.
auto inputs = ins->inputs(); auto inputs = ins->inputs();
for(auto input : inputs) for(auto input : inputs)
{ {
auto s = input->get_shape(); auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type) if(s.type() == shape::float_type || s.type() == shape::double_type)
{ {
auto input_fp16 = insert_fp16(prog, input, s.type(), map_fp16); // if the input is a fp_conversion operator, uses its input
// as its current input
instruction_ref input_fp16{};
if (input->name() == "fp_conversion")
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
}
instruction::replace_argument(ins, input, input_fp16, false); instruction::replace_argument(ins, input, input_fp16, false);
} }
} }
// recompute the output shape
ins->recompute_ins_shape(); ins->recompute_ins_shape();
if(ins->get_shape().type() == shape::half_type) // If output is not the original type, add another instruction
// to convert it back to the original type
if(ins->get_shape().type() != orig_type)
{ {
instruction_ref ins_orig_type{};
if (ins == std::prev(prog.end()))
{
ins_orig_type = prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
{
ins_orig_type = prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins);
}
prog.replace_instruction(ins, ins_orig_type);
} }
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment