Unverified Commit f6e22d56, authored by Paul Fultz II and committed by GitHub

Require the same type for the inputs and scales for QuantizeLinear (#1642)

Converts can be inserted when the scales and the input differ in type in the ONNX file (we are already doing this implicit conversion in the reference implementation). This will also improve the compile time of quantizelinear.hpp, since we can remove the nested visit method.
parent adccec52
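
For orientation before the diff: the change adds an `insert_common_args`/`add_common_args` pair to common.hpp and uses it in the QuantizeLinear parser. Below is a minimal usage sketch, assuming only the signatures added in this change; the helper `quantize_with_common_args` is hypothetical and not part of the diff.

#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

// Hypothetical usage sketch: add_common_args broadcasts the given arguments
// to a common shape and inserts "convert" instructions where their types
// differ, so the op built below always sees matching input and scale types.
migraphx::instruction_ref
quantize_with_common_args(migraphx::module& m,
                          migraphx::instruction_ref x,       // e.g. half input
                          migraphx::instruction_ref y_scale) // e.g. float scales
{
    auto args = migraphx::add_common_args(m, {x, y_scale});
    return m.add_instruction(migraphx::make_op("quantizelinear"), args);
}

Note that the zero point is deliberately left out of the common-args handling throughout this change: its type determines the quantized output type, so it must not be converted.
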
@@ -148,10 +148,8 @@ shape common_shape(const std::vector<shape>& shapes)
     return {compute_common_types(shapes), compute_common_lens(shapes)};
 }
 
-instruction_ref insert_common_op(module& m,
-                                 instruction_ref ins,
-                                 const operation& op,
-                                 std::vector<instruction_ref> inputs)
+std::vector<instruction_ref>
+insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs)
 {
     if(std::any_of(
            inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
@@ -210,7 +208,20 @@
             return input;
         });
     }
-    return m.insert_instruction(ins, op, inputs);
+    return inputs;
+}
+
+std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs)
+{
+    return insert_common_args(m, m.end(), std::move(inputs));
+}
+
+instruction_ref insert_common_op(module& m,
+                                 instruction_ref ins,
+                                 const operation& op,
+                                 std::vector<instruction_ref> inputs)
+{
+    return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
+}
 
 instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
......
@@ -41,6 +41,11 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
 shape common_shape(const std::vector<shape>& shapes);
 
+std::vector<instruction_ref>
+insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs);
+
+std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs);
+
 instruction_ref insert_common_op(module& m,
                                  instruction_ref ins,
                                  const operation& op,
......
@@ -40,7 +40,11 @@ struct dequantizelinear
     std::string name() const { return "dequantizelinear"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_dims();
+        check_shapes{inputs, *this}.same_dims().has(2, 3);
+        if(inputs.size() == 3 and inputs[0].type() != inputs[2].type())
+        {
+            MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type.");
+        }
         return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
     }
......
@@ -40,7 +40,11 @@ struct quantizelinear
     std::string name() const { return "quantizelinear"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_dims();
+        check_shapes{inputs, *this}.same_dims().has(2, 3);
+        if(inputs[0].type() != inputs[1].type())
+        {
+            MIGRAPHX_THROW("QUANTIZELINEAR: Scales and input must be the same type");
+        }
         if(inputs.size() == 3)
         {
             return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
@@ -61,17 +65,15 @@ struct quantizelinear
         argument result{output_shape};
         visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
-            x.visit([&](auto input) {
-                y_scale.visit([&](auto scales) {
-                    using quant_type = typename decltype(output)::value_type;
-                    auto min_value   = std::numeric_limits<quant_type>::min();
-                    auto max_value   = std::numeric_limits<quant_type>::max();
-                    par_for(output_shape.elements(), [&](auto i) {
-                        int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
-                                            static_cast<int64_t>(zero_pts[i]);
-                        output[i] = std::max(static_cast<int64_t>(min_value),
-                                             std::min(static_cast<int64_t>(max_value), quantized));
-                    });
-                });
+            visit_all(x, y_scale)([&](auto input, auto scales) {
+                using quant_type = typename decltype(output)::value_type;
+                auto min_value   = std::numeric_limits<quant_type>::min();
+                auto max_value   = std::numeric_limits<quant_type>::max();
+                par_for(output_shape.elements(), [&](auto i) {
+                    int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
+                                        static_cast<int64_t>(zero_pts[i]);
+                    output[i] = std::max(static_cast<int64_t>(min_value),
+                                         std::min(static_cast<int64_t>(max_value), quantized));
+                });
             });
         });
......
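
The rewritten loop keeps the reference semantics of quantizelinear: per element, round the scaled input, add the zero point, and saturate to the range of the output type. A self-contained sketch of that arithmetic, with types chosen for illustration only (float input/scale, int8 output):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Per-element reference behavior mirrored from the loop above:
// y = clamp(round(x / scale) + zero_point, int8 min, int8 max).
std::int8_t quantize_element(float x, float scale, std::int8_t zero_point)
{
    const auto quantized = static_cast<std::int64_t>(std::round(x / scale)) +
                           static_cast<std::int64_t>(zero_point);
    const auto min_value = static_cast<std::int64_t>(std::numeric_limits<std::int8_t>::min());
    const auto max_value = static_cast<std::int64_t>(std::numeric_limits<std::int8_t>::max());
    return static_cast<std::int8_t>(std::max(min_value, std::min(max_value, quantized)));
}
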
@@ -26,6 +26,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
+#include <migraphx/common.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
         auto input_lens = args[0]->get_shape().lens();
         auto n_dim      = input_lens.size();
 
-        instruction_ref y_scale;
+        instruction_ref y_scale = args[1];
         if(args[1]->get_shape().elements() != 1)
         {
             auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
             y_scale         = info.add_instruction(
                 make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
         }
-        else
-        {
-            y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
-                                           args[1]);
-        }
+
+        auto common_args = add_common_args(*info.mod, {args[0], y_scale});
 
         if(args.size() == 3)
         {
@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
                 make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
             }
 
-            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
+            common_args.push_back(y_zero_point);
         }
-        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
+        return info.add_instruction(make_op("quantizelinear"), common_args);
     }
 };
......
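
Condensed, the parser's new control flow amounts to the hypothetical helper below. It is a sketch that assumes all arguments already share the same dimensions, so the per-axis "broadcast" of the scale and the "multibroadcast" of the zero point handled by the real parser above are skipped.

#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

#include <vector>

// Hypothetical condensed form of the parse logic: unify the input and scale
// (shape and type) via add_common_args, then append the zero point as-is so
// it keeps the integer type that determines the output type.
migraphx::instruction_ref
build_quantizelinear(migraphx::module& m, std::vector<migraphx::instruction_ref> args)
{
    auto common_args = migraphx::add_common_args(m, {args[0], args[1]});
    if(args.size() == 3)
        common_args.push_back(args[2]); // y_zero_point, untouched
    return m.add_instruction(migraphx::make_op("quantizelinear"), common_args);
}
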
@@ -2014,6 +2014,62 @@ TEST_CASE(quant_dot_2args)
     }
 }
 
+TEST_CASE(qlinear)
+{
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(qlinear_zeros)
+{
+    migraphx::shape zeros{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::int8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales, zeros);
+}
+
+TEST_CASE(qlinear_fp16)
+{
+    migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(qlinear_mismatch_type)
+{
+    migraphx::shape scales{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    throws_shape(migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear)
+{
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::float_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear_fp16)
+{
+    migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::half_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear_mismatch_type)
+{
+    migraphx::shape zeros{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
+}
+
 template <class T>
 void test_reduce_ops()
 {
......
@@ -33,12 +33,20 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
 #include <migraphx/verify.hpp>
+#include <migraphx/pass_manager.hpp>
 
 bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
 bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
+void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }
+
+migraphx::argument eval(const migraphx::program& p)
+{
+    auto r = p.eval({});
+    EXPECT(r.size() == 1);
+    return r.front();
+}
 
 TEST_CASE(quantizelinear)
 {
@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::rewrite_quantization opt;
-    opt.apply(*p2.get_main_module());
+    run_pass(*p2.get_main_module());
+
     EXPECT(eval(p1) == eval(p2));
     EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
     EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
 }
@@ -71,9 +79,9 @@ TEST_CASE(dequantizelinear)
     std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250};
     migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}};
     std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2};
-    migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}};
-    std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
-    auto create_program     = [&]() {
+    migraphx::shape zs{migraphx::shape::float_type, {1, 3, 3}};
+    std::vector<float> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
+    auto create_program   = [&]() {
         migraphx::program p;
         auto* mm = p.get_main_module();
         auto x   = mm->add_literal(xs, xv);
@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::rewrite_quantization opt;
-    opt.apply(*p2.get_main_module());
+    run_pass(*p2.get_main_module());
+
     EXPECT(eval(p1) == eval(p2));
     EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear));
     EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
 }
......
@@ -402,9 +402,10 @@ TEST_CASE(conv_bias_add)
     auto bias    = m1.add_parameter("bias", s6);
     auto scale   = m1.add_literal(0.5f);
     auto zero    = m1.add_literal(std::int8_t{0});
+    auto zero32  = m1.add_literal(std::int32_t{0});
 
     auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
-    auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
+    auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
     auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
     auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
     auto c1 = m1.add_instruction(migraphx::make_op("convolution",
@@ -428,9 +429,10 @@ TEST_CASE(conv_bias_add)
     auto bias    = m2.add_parameter("bias", s6);
     auto scale   = m2.add_literal(0.5f);
     auto zero    = m2.add_literal(std::int8_t{0});
+    auto zero32  = m2.add_literal(std::int32_t{0});
     auto scale1  = m2.add_literal(0.25f);
 
-    auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
+    auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
     auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
     auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
                                                    {{"padding", {0, 0, 0, 0}},
@@ -468,9 +470,10 @@ TEST_CASE(conv_pooling_dot)
     auto input   = m1.add_parameter("input", s7);
     auto scale   = m1.add_literal(0.5f);
     auto zero    = m1.add_literal(std::int8_t{0});
+    auto zero32  = m1.add_literal(std::int32_t{0});
 
     auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
-    auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
+    auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
     auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero);
     auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero);
     auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
@@ -515,10 +518,11 @@ TEST_CASE(conv_pooling_dot)
     auto input   = m2.add_parameter("input", s7);
     auto scale   = m2.add_literal(0.5f);
     auto zero    = m2.add_literal(std::int8_t{0});
+    auto zero32  = m2.add_literal(std::int32_t{0});
     auto scale1  = m2.add_literal(0.25f);
     auto scale2  = m2.add_literal(0.25f);
 
-    auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
+    auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
     auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
     auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
     auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
@@ -572,9 +576,10 @@ TEST_CASE(mobilenet_snippet)
     auto input   = mm.add_parameter("input", s7);
     auto scale   = mm.add_literal(0.5f);
     auto zero    = mm.add_literal(std::int8_t{0});
+    auto zero32  = mm.add_literal(std::int32_t{0});
 
     auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero);
-    auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero);
+    auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero32);
     auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero);
     auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero);
     auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
......
@@ -37,10 +37,13 @@ struct test_quantizelinear_int32 : verify_program<test_quantizelinear_int32>
         migraphx::shape sx{migraphx::shape::int32_type, {2, 2, 2}};
         migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}};
         migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}};
-        auto input1 = mm->add_parameter("x", sx);
-        auto input2 = mm->add_parameter("y_scale", ss);
-        auto input3 = mm->add_parameter("y_zero_point", sz);
-        auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), input1, input2, input3);
+        auto input1       = mm->add_parameter("x", sx);
+        auto input2       = mm->add_parameter("y_scale", ss);
+        auto input3       = mm->add_parameter("y_zero_point", sz);
+        auto input1_float = mm->add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), input1);
+        auto r =
+            mm->add_instruction(migraphx::make_op("quantizelinear"), input1_float, input2, input3);
         mm->add_return({r});
         return p;
     };
......