Unverified Commit f6e22d56 authored by Paul Fultz II, committed by GitHub

Require the same type for the inputs and scales for QuantizeLinear (#1642)

Converts can be inserted when the scales and input types differ in the ONNX file (we were already doing this conversion implicitly in the reference implementation). This also improves the compile time of quantizelinear.hpp, since we can remove the nested visit method.
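For example, a graph that previously leaned on the implicit conversion now has to spell it out. A minimal sketch of the new contract, assuming the MIGraphX C++ API as used in the tests below (shapes and names chosen for illustration, not taken from this PR):

#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape xs{migraphx::shape::half_type, {2, 4}};
    migraphx::shape ss{migraphx::shape::float_type, {2, 4}};
    auto x      = mm->add_parameter("x", xs);
    auto scales = mm->add_parameter("y_scale", ss);
    // Passing x directly would now throw
    // "QUANTIZELINEAR: Scales and input must be the same type",
    // so convert the input to the scales' type first:
    auto xf = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
    mm->add_instruction(migraphx::make_op("quantizelinear"), xf, scales);
}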
parent adccec52
@@ -148,10 +148,8 @@ shape common_shape(const std::vector<shape>& shapes)
     return {compute_common_types(shapes), compute_common_lens(shapes)};
 }
 
-instruction_ref insert_common_op(module& m,
-                                 instruction_ref ins,
-                                 const operation& op,
-                                 std::vector<instruction_ref> inputs)
+std::vector<instruction_ref>
+insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs)
 {
     if(std::any_of(
        inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
@@ -210,7 +208,20 @@ instruction_ref insert_common_op(module& m,
             return input;
         });
     }
-    return m.insert_instruction(ins, op, inputs);
+    return inputs;
+}
+
+std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs)
+{
+    return insert_common_args(m, m.end(), std::move(inputs));
+}
+
+instruction_ref insert_common_op(module& m,
+                                 instruction_ref ins,
+                                 const operation& op,
+                                 std::vector<instruction_ref> inputs)
+{
+    return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
 }
 
 instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
...
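In effect, insert_common_args now owns the broadcast/convert normalization that insert_common_op previously did inline, and add_common_args is the insert-at-end convenience wrapper. A minimal usage sketch, assuming m is a migraphx::module and x/scale are instruction_refs that may differ in shape or type (the helper name is hypothetical):

migraphx::instruction_ref make_qlinear(migraphx::module& m,
                                       migraphx::instruction_ref x,
                                       migraphx::instruction_ref scale)
{
    // May insert multibroadcast/convert instructions into m so that both
    // refs end up with a common shape and type.
    auto args = migraphx::add_common_args(m, {x, scale});
    return m.add_instruction(migraphx::make_op("quantizelinear"), args);
    // One-step equivalent using the existing helper:
    // migraphx::add_common_op(m, migraphx::make_op("quantizelinear"), {x, scale});
}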
@@ -41,6 +41,11 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
 
 shape common_shape(const std::vector<shape>& shapes);
 
+std::vector<instruction_ref>
+insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs);
+
+std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs);
+
 instruction_ref insert_common_op(module& m,
                                  instruction_ref ins,
                                  const operation& op,
...
@@ -40,7 +40,11 @@ struct dequantizelinear
     std::string name() const { return "dequantizelinear"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_dims();
+        check_shapes{inputs, *this}.same_dims().has(2, 3);
+        if(inputs.size() == 3 and inputs[0].type() != inputs[2].type())
+        {
+            MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type.");
+        }
         return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
     }
...
@@ -40,7 +40,11 @@ struct quantizelinear
     std::string name() const { return "quantizelinear"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_dims();
+        check_shapes{inputs, *this}.same_dims().has(2, 3);
+        if(inputs[0].type() != inputs[1].type())
+        {
+            MIGRAPHX_THROW("QUANTIZELINEAR: Scales and input must be the same type");
+        }
         if(inputs.size() == 3)
         {
             return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
@@ -61,8 +65,7 @@ struct quantizelinear
         argument result{output_shape};
         visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
-            x.visit([&](auto input) {
-                y_scale.visit([&](auto scales) {
+            visit_all(x, y_scale)([&](auto input, auto scales) {
                 using quant_type = typename decltype(output)::value_type;
                 auto min_value   = std::numeric_limits<quant_type>::min();
                 auto max_value   = std::numeric_limits<quant_type>::max();
@@ -74,7 +77,6 @@
                 });
             });
         });
-        });
         return result;
     }
...
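The visit_all change above is where the compile-time improvement comes from: the nested x.visit/y_scale.visit stamped out the quantization lambda once per (input type, scale type) pair, while visit_all can hand both arguments to one lambda with a single shared value type. A standalone toy model of the two dispatch shapes (nothing below is MIGraphX API; it only mirrors the instantiation pattern):

#include <iostream>

// Maps one runtime type tag to a static type, like a single-argument visit.
template <class F>
void visit(int tag, F f)
{
    if(tag == 0)
        f(int{});
    else
        f(float{});
}

// Visits two values with one shared type, which is only possible because
// the caller guarantees up front that the types match.
template <class F>
void visit_all(int tag, F f)
{
    if(tag == 0)
        f(int{}, int{});
    else
        f(float{}, float{});
}

int main()
{
    // Nested visits: the inner lambda body is instantiated for every
    // (input, scales) type pair, O(N^2) in the number of supported types.
    visit(0, [](auto input) {
        visit(1, [&](auto scales) { std::cout << sizeof(input) + sizeof(scales) << '\n'; });
    });

    // Single visit_all: one instantiation per type, O(N).
    visit_all(1, [](auto input, auto scales) {
        std::cout << sizeof(input) + sizeof(scales) << '\n';
    });
}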
@@ -26,6 +26,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
+#include <migraphx/common.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
         auto input_lens = args[0]->get_shape().lens();
         auto n_dim      = input_lens.size();
 
-        instruction_ref y_scale;
+        instruction_ref y_scale = args[1];
         if(args[1]->get_shape().elements() != 1)
         {
             auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
             y_scale = info.add_instruction(
                 make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
         }
-        else
-        {
-            y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
-                                           args[1]);
-        }
+
+        auto common_args = add_common_args(*info.mod, {args[0], y_scale});
 
         if(args.size() == 3)
         {
@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
                     make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
             }
-            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
+            common_args.push_back(y_zero_point);
         }
-        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
+        return info.add_instruction(make_op("quantizelinear"), common_args);
     }
 };
...
@@ -2014,6 +2014,62 @@ TEST_CASE(quant_dot_2args)
     }
 }
 
+TEST_CASE(qlinear)
+{
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(qlinear_zeros)
+{
+    migraphx::shape zeros{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::int8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales, zeros);
+}
+
+TEST_CASE(qlinear_fp16)
+{
+    migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(qlinear_mismatch_type)
+{
+    migraphx::shape scales{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::float_type, {2, 4}};
+    throws_shape(migraphx::make_op("quantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear)
+{
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::float_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear_fp16)
+{
+    migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    migraphx::shape result{migraphx::shape::half_type, {2, 4}};
+    expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
+}
+
+TEST_CASE(dqlinear_mismatch_type)
+{
+    migraphx::shape zeros{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
+    migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
+    throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
+}
+
 template <class T>
 void test_reduce_ops()
 {
...
@@ -33,12 +33,20 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/verify.hpp>
 
 bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
 bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
+void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }
+
+migraphx::argument eval(const migraphx::program& p)
+{
+    auto r = p.eval({});
+    EXPECT(r.size() == 1);
+    return r.front();
+}
 
 TEST_CASE(quantizelinear)
 {
@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
 
-    migraphx::rewrite_quantization opt;
-    opt.apply(*p2.get_main_module());
+    run_pass(*p2.get_main_module());
+    EXPECT(eval(p1) == eval(p2));
     EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
     EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
 }
@@ -71,8 +79,8 @@ TEST_CASE(dequantizelinear)
     std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250};
     migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}};
     std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2};
-    migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}};
-    std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
+    migraphx::shape zs{migraphx::shape::float_type, {1, 3, 3}};
+    std::vector<float> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
     auto create_program = [&]() {
         migraphx::program p;
         auto* mm = p.get_main_module();
@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
 
-    migraphx::rewrite_quantization opt;
-    opt.apply(*p2.get_main_module());
+    run_pass(*p2.get_main_module());
+    EXPECT(eval(p1) == eval(p2));
     EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear));
     EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
 }
...
@@ -402,9 +402,10 @@ TEST_CASE(conv_bias_add)
         auto bias    = m1.add_parameter("bias", s6);
         auto scale   = m1.add_literal(0.5f);
         auto zero    = m1.add_literal(std::int8_t{0});
+        auto zero32  = m1.add_literal(std::int32_t{0});
         auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
-        auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
+        auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
         auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
         auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
         auto c1 = m1.add_instruction(migraphx::make_op("convolution",
@@ -428,9 +429,10 @@ TEST_CASE(conv_bias_add)
         auto bias   = m2.add_parameter("bias", s6);
         auto scale  = m2.add_literal(0.5f);
         auto zero   = m2.add_literal(std::int8_t{0});
+        auto zero32 = m2.add_literal(std::int32_t{0});
         auto scale1 = m2.add_literal(0.25f);
-        auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
+        auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
         auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
         auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
                                      {{"padding", {0, 0, 0, 0}},
@@ -468,9 +470,10 @@ TEST_CASE(conv_pooling_dot)
         auto input = m1.add_parameter("input", s7);
         auto scale = m1.add_literal(0.5f);
         auto zero  = m1.add_literal(std::int8_t{0});
+        auto zero32 = m1.add_literal(std::int32_t{0});
         auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
-        auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
+        auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
         auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero);
         auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero);
         auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
@@ -515,10 +518,11 @@ TEST_CASE(conv_pooling_dot)
         auto input  = m2.add_parameter("input", s7);
         auto scale  = m2.add_literal(0.5f);
         auto zero   = m2.add_literal(std::int8_t{0});
+        auto zero32 = m2.add_literal(std::int32_t{0});
         auto scale1 = m2.add_literal(0.25f);
         auto scale2 = m2.add_literal(0.25f);
-        auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
+        auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
         auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
         auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
         auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
@@ -572,9 +576,10 @@ TEST_CASE(mobilenet_snippet)
         auto input = mm.add_parameter("input", s7);
         auto scale = mm.add_literal(0.5f);
         auto zero  = mm.add_literal(std::int8_t{0});
+        auto zero32 = mm.add_literal(std::int32_t{0});
         auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero);
-        auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero);
+        auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero32);
         auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero);
         auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero);
         auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
...
@@ -40,7 +40,10 @@ struct test_quantizelinear_int32 : verify_program<test_quantizelinear_int32>
         auto input1 = mm->add_parameter("x", sx);
         auto input2 = mm->add_parameter("y_scale", ss);
         auto input3 = mm->add_parameter("y_zero_point", sz);
-        auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), input1, input2, input3);
+        auto input1_float = mm->add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), input1);
+        auto r =
+            mm->add_instruction(migraphx::make_op("quantizelinear"), input1_float, input2, input3);
         mm->add_return({r});
         return p;
     };
...