Commit 9f48c99e authored by Alan Turner's avatar Alan Turner
Browse files

Convert input to scale type in rewrite_quantization

parent 08546656
...@@ -37,10 +37,11 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -37,10 +37,11 @@ void apply_quantizelinear(module& m, instruction_ref ins)
assert(ins->name() == "quantizelinear"); assert(ins->name() == "quantizelinear");
auto x = ins->inputs()[0]; auto x = ins->inputs()[0];
auto y_scale = ins->inputs()[1]; auto y_scale = ins->inputs()[1];
auto target_type = y_scale->get_shape().type();
if(x->get_shape().type() != y_scale->get_shape().type()) if(x->get_shape().type() != y_scale->get_shape().type())
{ {
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), x); x = m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), x);
} }
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
...@@ -48,7 +49,7 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -48,7 +49,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto zero_point = m.insert_instruction( auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[2]); ins, make_op("convert", {{"target_type", target_type}}), ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
...@@ -72,14 +73,16 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -72,14 +73,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void apply_dequantizelinear(module& m, instruction_ref ins) void apply_dequantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "dequantizelinear"); assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1]; auto x_scale = ins->inputs()[1];
auto target_type = x_scale->get_shape().type();
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", target_type}}), ins->inputs()[0]);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto x_zero_point = m.insert_instruction( auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[2]); ins, make_op("convert", {{"target_type", target_type}}), ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point); x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment