Commit 13e4b719 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments from code review.

parent 92fd618a
...@@ -18,8 +18,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -18,8 +18,8 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins) instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{ {
assert(ins->get_shape().type() == shape::float_type); assert(ins->get_shape().type() == shape::float_type || ins->get_shape().type() == shape::double_type);
assert(ins->name().front() == '@'); assert(contains({"@literal", "@param"}, ins->name()));
instruction_ref ins_fp16{}; instruction_ref ins_fp16{};
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
...@@ -47,11 +47,15 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins) ...@@ -47,11 +47,15 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
void quantize(program& prog) void quantize(program& prog)
{ {
bool reduced_precision = false; bool reduced_precision = false;
shape::type_t orig_type = shape::float_type;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
// convert float_type to half_type // convert float_type to half_type
if(ins->name().front() == '@' && ins->get_shape().type() == shape::float_type) if(contains({"@literal", "@param"}, ins->name()) &&
(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type))
{ {
orig_type = ins->get_shape().type();
auto ins_fp16 = convert_fp32_fp16(prog, ins); auto ins_fp16 = convert_fp32_fp16(prog, ins);
auto outputs = ins->outputs(); auto outputs = ins->outputs();
for(auto output : outputs) for(auto output : outputs)
...@@ -71,7 +75,7 @@ void quantize(program& prog) ...@@ -71,7 +75,7 @@ void quantize(program& prog)
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
if(ins->name().front() != '@') if(!contains({"@literal", "@param"}, ins->name()))
{ {
ins->recompute_ins_shape(); ins->recompute_ins_shape();
} }
...@@ -80,7 +84,7 @@ void quantize(program& prog) ...@@ -80,7 +84,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end()); auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type) if(ins->get_shape().type() == shape::half_type)
{ {
prog.add_instruction(op::fp_conversion{shape::float_type}, ins); prog.add_instruction(op::fp_conversion{orig_type}, ins);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment