Commit f70710c8 authored by Shucai Xiao
Browse files

add a function to convert one instruction from fp32 to fp16

parent fbc9dad7
...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -18,7 +18,7 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins) instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{ {
assert(ins->get_shape().type() == shape::float_type); assert(ins->get_shape().type() == shape::float_type || ins->get_shape().type() == shape::double_type);
assert(ins->name().front() == '@'); assert(ins->name().front() == '@');
instruction_ref ins_fp16{}; instruction_ref ins_fp16{};
if(ins->name() == "@literal") if(ins->name() == "@literal")
......
...@@ -8,8 +8,32 @@ ...@@ -8,8 +8,32 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Returns the fp_conversion instruction associated with `ins`, creating it
// on first request and caching it in `map_fp16` so that every later request
// for the same `ins` yields the same conversion instruction.
//
// prog     - program that receives the new fp_conversion instruction.
// ins      - instruction whose output is to be converted; its shape type is
//            asserted to be float_type or double_type.
// type     - shape type supplied by the caller; not read by this function.
// map_fp16 - cache from an original instruction to the conversion that was
//            already created for it; updated when a new one is created.
//
// The new instruction is placed directly after `ins`: appended with
// add_instruction when `ins` is the last instruction of the program,
// otherwise inserted before std::next(ins).
instruction_ref insert_fp16(program& prog,
                            instruction_ref& ins,
                            shape::type_t type,
                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
    // Single lookup: reuse the conversion if this instruction was seen before.
    auto cached = map_fp16.find(ins);
    if(cached != map_fp16.end())
    {
        return cached->second;
    }

    // Kept in the signature for callers; silence unused-parameter warnings.
    (void)type;

    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);

    instruction_ref ins_fp16{};
    if(ins == std::prev(prog.end()))
    {
        ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins);
    }
    else
    {
        ins_fp16 = prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
    }

    map_fp16[ins] = ins_fp16;
    return ins_fp16;
}
void quantize_ins(program& prog, const std::vector<std::string>& ins_names) void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
{ {
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
auto name_it = std::find(ins_name.begin(), ins_name.end(), ins->name()); auto name_it = std::find(ins_name.begin(), ins_name.end(), ins->name());
...@@ -17,6 +41,23 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -17,6 +41,23 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
{ {
continue; continue;
} }
auto inputs = ins->inputs();
for (auto input : inputs)
{
auto s = input->get_shape();
if (s.type() == shape::float_type || s.type() == shape::double_type)
{
auto input_fp16 = insert_fp16(prog, input, s.type(), map_fp16);
instruction::replace_argument(ins, input, input_fp16, false);
}
}
ins->recompute_ins_shape();
if (ins->get_shape().type() == shape::half_type)
{
}
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment