Commit 13e4b719 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments from code review.

parent 92fd618a
......@@ -18,8 +18,8 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{
assert(ins->get_shape().type() == shape::float_type);
assert(ins->name().front() == '@');
assert(ins->get_shape().type() == shape::float_type || ins->get_shape().type() == shape::double_type);
assert(contains({"@literal", "@param"}, ins->name()));
instruction_ref ins_fp16{};
if(ins->name() == "@literal")
{
......@@ -47,11 +47,15 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
void quantize(program& prog)
{
bool reduced_precision = false;
shape::type_t orig_type = shape::float_type;
for(auto ins : iterator_for(prog))
{
// convert float_type to half_type
if(ins->name().front() == '@' && ins->get_shape().type() == shape::float_type)
if(contains({"@literal", "@param"}, ins->name()) &&
(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type))
{
orig_type = ins->get_shape().type();
auto ins_fp16 = convert_fp32_fp16(prog, ins);
auto outputs = ins->outputs();
for(auto output : outputs)
......@@ -71,7 +75,7 @@ void quantize(program& prog)
{
for(auto ins : iterator_for(prog))
{
if(ins->name().front() != '@')
if(!contains({"@literal", "@param"}, ins->name()))
{
ins->recompute_ins_shape();
}
......@@ -80,7 +84,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type)
{
prog.add_instruction(op::fp_conversion{shape::float_type}, ins);
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment