Commit f80061db authored by Paul's avatar Paul
Browse files

Format

parent 1ccbd8ce
......@@ -524,9 +524,9 @@ struct verify : command<verify>
auto m = c.parameters.generate(p, t, true, c.l.batch);
auto quantize = precision::fp32;
if (c.to_fp16)
if(c.to_fp16)
quantize = precision::fp16;
if (c.to_int8)
if(c.to_int8)
quantize = precision::int8;
if(per_instruction)
......
......@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(x->get_shape().type() != y_scale->get_shape().type())
{
x = m.insert_instruction(ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
x = m.insert_instruction(
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), ins->inputs()[2]);
auto zero_point =
m.insert_instruction(ins,
make_op("convert", {{"target_type", y_scale->get_shape().type()}}),
ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
......@@ -73,13 +76,15 @@ void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
auto x_scale = ins->inputs()[1];
auto x = m.insert_instruction(
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[0]);
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[2]);
auto x_zero_point =
m.insert_instruction(ins,
make_op("convert", {{"target_type", x_scale->get_shape().type()}}),
ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
......
......@@ -503,7 +503,9 @@ struct find_inner_broadcast
if(broadcasts.empty())
return;
// Skip if different data types are used
if (any_of(broadcasts, [&](auto i) { return i->get_shape().type() != broadcasts.front()->get_shape().type(); }))
if(any_of(broadcasts, [&](auto i) {
return i->get_shape().type() != broadcasts.front()->get_shape().type();
}))
return;
bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
any_of(broadcasts, non_scalar_op("multibroadcast"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment