"test/vscode:/vscode.git/clone" did not exist on "89120fa1a72b3bb0b9caed920942e13d61e88c88"
Unverified Commit 26c1efa5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable quantizing both int8 and fp16 in the driver (#1757)

* Allow quantizing for both int8 and fp16
parent 763dd1da
@@ -415,7 +415,8 @@ struct compiler
     program_params parameters;
     compiler_target ct;
     compile_options co;
-    precision quantize = precision::fp32;
+    bool to_fp16 = false;
+    bool to_int8 = false;
     std::vector<std::string> fill0;
     std::vector<std::string> fill1;
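With a single precision field, --fp16 and --int8 overwrote each other, so only the last flag on the command line took effect; two independent booleans let the driver record, and later apply, both requests. A toy sketch of the same pattern, using a hand-rolled loop rather than the driver's argument_parser:

#include <cstring>

// Toy stand-in for the driver's argument_parser: each flag owns its own
// bool, so "--fp16 --int8" enables both quantization passes at once.
struct options
{
    bool to_fp16 = false;
    bool to_int8 = false;
};

options parse_flags(int argc, char** argv)
{
    options o;
    for(int i = 1; i < argc; i++)
    {
        if(std::strcmp(argv[i], "--fp16") == 0)
            o.to_fp16 = true;
        else if(std::strcmp(argv[i], "--int8") == 0)
            o.to_int8 = true;
    }
    return o;
}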
@@ -436,8 +437,8 @@ struct compiler
        {"--exhaustive-tune"},
        ap.help("Exhaustively search for best tuning parameters for kernels"),
        ap.set_value(true));
-        ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
-        ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
+        ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
+        ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
     }
     auto params(const program& p)
@@ -445,6 +446,11 @@ struct compiler
         return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
     }
+    auto host_params(const program& p)
+    {
+        return parameters.generate(p, ct.get_target(), true, l.batch);
+    }
     program compile()
     {
         auto p = l.load();
@@ -452,13 +458,13 @@ struct compiler
         if(p.is_compiled())
             return p;
         auto t = ct.get_target();
-        if(quantize == precision::fp16)
+        if(to_fp16)
         {
             quantize_fp16(p);
         }
-        else if(quantize == precision::int8)
+        if(to_int8)
         {
-            quantize_int8(p, t, {params(p)});
+            quantize_int8(p, t, {host_params(p)});
         }
         p.compile(t, co);
         l.save(p);
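Two details of the new compile path are worth noting. First, the second check is now a plain if rather than else if, so fp16 conversion runs first and int8 calibration then operates on the already-converted program. Second, quantize_int8 now takes host_params, which forces the offload_copy argument of generate to true, presumably so calibration inputs are host-visible regardless of co.offload_copy. A toy model of the two-pass flow, with string tags standing in for the real passes:

#include <iostream>
#include <string>

// Toy model of the new compile flow: two independent ifs replace the old
// if/else-if, so both passes can run on one program. These functions are
// stand-ins, not the MIGraphX quantize_fp16/quantize_int8 passes.
void quantize_fp16(std::string& p) { p += "+fp16"; }
void quantize_int8(std::string& p) { p += "+int8"; }

std::string compile(bool to_fp16, bool to_int8)
{
    std::string p = "program";
    if(to_fp16)
        quantize_fp16(p);
    if(to_int8) // no longer "else if": both quantizations can apply
        quantize_int8(p);
    return p;
}

int main() { std::cout << compile(true, true) << "\n"; } // prints program+fp16+int8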
@@ -517,17 +523,23 @@ struct verify : command<verify>
         auto t = c.ct.get_target();
         auto m = c.parameters.generate(p, t, true, c.l.batch);
+        auto quantize = precision::fp32;
+        if(c.to_fp16)
+            quantize = precision::fp16;
+        if(c.to_int8)
+            quantize = precision::int8;
         if(per_instruction)
         {
-            verify_instructions(p, t, c.co, c.quantize, tolerance);
+            verify_instructions(p, t, c.co, quantize, tolerance);
         }
         else if(reduce)
         {
-            verify_reduced_program(p, t, c.co, c.quantize, m, tolerance);
+            verify_reduced_program(p, t, c.co, quantize, m, tolerance);
         }
         else
         {
-            verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance);
+            verify_program(c.l.file, p, t, c.co, quantize, m, tolerance);
         }
     }
 };
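The verify command still needs a single precision value, so it derives one from the two flags; because to_int8 is tested last, int8 wins when both flags are given. The mapping in isolation, reusing the driver's existing precision enum values:

enum class precision { fp32, fp16, int8 };

// Derive one precision from the two driver flags, as verify does above;
// int8 takes priority because it is assigned last.
precision pick_precision(bool to_fp16, bool to_int8)
{
    auto q = precision::fp32;
    if(to_fp16)
        q = precision::fp16;
    if(to_int8)
        q = precision::int8;
    return q;
}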
@@ -40,15 +40,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     if(x->get_shape().type() != y_scale->get_shape().type())
     {
-        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
+        x = m.insert_instruction(
+            ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
     }
     auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
     auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
     if(ins->inputs().size() == 3)
     {
-        auto zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+        auto zero_point =
+            m.insert_instruction(ins,
+                                 make_op("convert", {{"target_type", y_scale->get_shape().type()}}),
+                                 ins->inputs()[2]);
         add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
     }
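This rewrite builds ONNX's QuantizeLinear formula, y = saturate(round(x / y_scale) + y_zero_point), out of convert, div, round, and add instructions; the change above computes it in y_scale's type instead of always widening to fp32, so fp16 scales stay in fp16. A scalar sketch of the formula itself (a hypothetical helper, not MIGraphX code), assuming an int8 output:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch of QuantizeLinear for an int8 output:
// y = saturate(round(x / y_scale) + y_zero_point). T is the scale's type,
// matching the convert the pass now inserts.
template <class T>
std::int8_t quantize_linear(T x, T y_scale, std::int8_t y_zero_point)
{
    T v = std::round(x / y_scale) + static_cast<T>(y_zero_point);
    v   = std::min<T>(std::max<T>(v, T{-128}), T{127}); // saturate to int8 range
    return static_cast<std::int8_t>(v);
}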
@@ -72,14 +75,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
 void apply_dequantizelinear(module& m, instruction_ref ins)
 {
     assert(ins->name() == "dequantizelinear");
-    auto x = m.insert_instruction(
-        ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
     auto x_scale = ins->inputs()[1];
+    auto x = m.insert_instruction(
+        ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[0]);
     if(ins->inputs().size() == 3)
     {
-        auto x_zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+        auto x_zero_point =
+            m.insert_instruction(ins,
+                                 make_op("convert", {{"target_type", x_scale->get_shape().type()}}),
+                                 ins->inputs()[2]);
         x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
     }
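apply_dequantizelinear mirrors this for DequantizeLinear, x = (y - x_zero_point) * x_scale; note that x_scale must now be read before x so its type is available for the convert. The same formula as a scalar sketch (again a hypothetical helper):

#include <cstdint>

// Scalar sketch of DequantizeLinear: x = (y - x_zero_point) * x_scale,
// computed in x_scale's type T rather than forcing an fp32 round trip.
template <class T>
T dequantize_linear(std::int8_t y, T x_scale, std::int8_t x_zero_point)
{
    return (static_cast<T>(y) - static_cast<T>(x_zero_point)) * x_scale;
}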
@@ -501,6 +501,11 @@ struct find_inner_broadcast
         auto broadcasts = ins->inputs();
         if(broadcasts.empty())
             return;
+        // Skip if the broadcast inputs use different data types
+        if(any_of(broadcasts, [&](auto i) {
+               return i->get_shape().type() != broadcasts.front()->get_shape().type();
+           }))
+            return;
         bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
                                 any_of(broadcasts, non_scalar_op("multibroadcast"));
         // If the broadcast is not a single dimension, then don't perform inner_broadcast
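The guard makes find_inner_broadcast bail out when the broadcast inputs disagree on element type, presumably because quantized programs now mix types more often (for example an fp16 tensor broadcast against an int8 one) and hoisting the op under a single broadcast would require an implicit convert. The added check is the usual any_of-against-the-first-element pattern, shown standalone here with shape types reduced to an enum:

#include <algorithm>
#include <vector>

enum class dtype { fp32, fp16, int8 };

// Standalone version of the added guard: true when any input's element
// type differs from the first input's type.
bool has_mixed_types(const std::vector<dtype>& types)
{
    return std::any_of(types.begin(), types.end(), [&](dtype t) {
        return t != types.front();
    });
}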