Unverified commit 3282e01a authored by turneram, committed by GitHub

Quantize linear ops (#843)

* Add operators, refactor parsers, add rewrite passes, add tests

* Formatting

* Fix cppcheck

* Review comments

* Formatting

* Combine rewrite passes

* Formatting

* Add ref implementations

* Formatting

* Review comments

* Formatting

* Tidy warnings

* Apply review comments

* Formatting

* Fix CI error

* Formatting

* Increase code coverage

* Formatting

* Move broadcasting of scales and zero points to onnx parser

* Formatting

* Allow for x and zero_point to have different types in quantizelinear; fix zero_point default type

* Formatting

* Increase code coverage

* Formatting

* Switch certain variables to int64_t

* Formatting

* Fix overflow in implicit constant conversion

* Formatting

* Increase code coverage

* Formatting

* Remove operators.hpp from includes in tf_test.cpp

* Formatting

* Add conversion for int32 input to quantizelinear and add test case; remove operators.hpp from onnx_test.cpp includes

* Formatting

* Switch dequantizelinear math from int32 to float

* Formatting

* Remove changes to operators.hpp

* Simplify apply_quantizelinear

* Formatting

* Add verify test for int32 data

* Add rewrite_quantization back to CMakeLists
parent 4983fecd
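For context, the ONNX operators wrapped by this change compute y = saturate(round(x / y_scale) + y_zero_point) and y = (x - x_zero_point) * x_scale. A minimal NumPy sketch of those semantics (illustrative names only, not code from this commit):

```python
import numpy as np

def quantize_linear(x, y_scale, y_zero_point=np.uint8(0)):
    # y = saturate(round(x / y_scale) + y_zero_point), saturated to the zero point's type
    info = np.iinfo(y_zero_point.dtype)
    q = np.round(x / y_scale) + np.int64(y_zero_point)
    return np.clip(q, info.min, info.max).astype(y_zero_point.dtype)

def dequantize_linear(y, x_scale, x_zero_point=0):
    # x = (y - x_zero_point) * x_scale, computed in float
    return (y.astype(np.float32) - np.float32(x_zero_point)) * x_scale

x = np.array([-1.5, 0.0, 0.6, 3.2], dtype=np.float32)
q = quantize_linear(x, np.float32(0.05), np.int8(10))
print(q)                                                    # int8 values saturated to [-128, 127]
print(dequantize_linear(q, np.float32(0.05), np.int8(10)))  # approximate round trip
```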
@@ -53,6 +53,7 @@ add_library(migraphx
     remap.cpp
     rewrite_batchnorm.cpp
     rewrite_pooling.cpp
+    rewrite_quantization.cpp
     rewrite_rnn.cpp
     schedule.cpp
     serialize.cpp
@@ -94,6 +95,7 @@ register_migraphx_ops(
     cosh
     cos
     deconvolution
+    dequantizelinear
     div
     dot
     elu
@@ -132,6 +134,7 @@ register_migraphx_ops(
     prelu
     quant_convolution
     quant_dot
+    quantizelinear
     recip
     reduce_max
     reduce_mean
...
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.elements(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
x_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
static_cast<int64_t>(zero_pts[i])) *
scales[i];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
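A rough NumPy mirror of the compute() above, assuming the scale and zero point have already been broadcast to the input shape (illustrative only):

```python
import numpy as np

def dequantizelinear_ref(x, x_scale, x_zero_point=None):
    # output[i] = (x[i] - zero_point[i]) * scale[i], with the subtraction done in integers
    if x_zero_point is None:
        x_zero_point = np.zeros_like(x)
    return ((x.astype(np.int64) - x_zero_point.astype(np.int64)) * x_scale).astype(np.float32)

print(dequantizelinear_ref(np.array([0, 10, 130], dtype=np.int32),
                           np.full(3, 0.5, dtype=np.float32),
                           np.full(3, 10, dtype=np.int32)))  # -> [-5.  0. 60.]
```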
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
std::string name() const { return "quantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() == 3)
{
return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
}
return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto y_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.elements(), 0);
argument y_zero_point{output_shape, zeros.data()};
if(args.size() == 3)
{
y_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
x.visit([&](auto input) {
y_scale.visit([&](auto scales) {
using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value),
std::min(static_cast<int64_t>(max_value), quantized));
});
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
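Likewise, a rough NumPy mirror of the saturating math in compute() above, assuming the scale and zero point are already broadcast to the input shape (illustrative only):

```python
import numpy as np

def quantizelinear_ref(x, y_scale, y_zero_point):
    # quantized = round(x / scale) + zero_point, clamped to the zero point's numeric range
    info = np.iinfo(y_zero_point.dtype)
    q = np.round(x / y_scale).astype(np.int64) + y_zero_point.astype(np.int64)
    return np.clip(q, info.min, info.max).astype(y_zero_point.dtype)

x = np.array([-300.0, 0.4, 2.7, 500.0], dtype=np.float32)
print(quantizelinear_ref(x, np.full(4, 2.0, dtype=np.float32),
                         np.zeros(4, dtype=np.int8)))  # -> [-128    0    1  127]
```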
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Rewrite quantization ops to equivalent operators
*/
struct rewrite_quantization
{
std::string name() const { return "rewrite_quantization"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
     instruction_ref parse(const op_desc& opd,
                           const onnx_parser& /*parser*/,
                           const onnx_parser::node_info& info,
-                          std::vector<instruction_ref> args) const
+                          const std::vector<instruction_ref>& args) const
     {
         int axis = 1;
         if(contains(info.attributes, "axis"))
             axis = info.attributes.at("axis").i();
         auto input_lens = args[0]->get_shape().lens();
-        int n_dim = static_cast<int>(input_lens.size());
+        auto n_dim = input_lens.size();
-        auto sub_zero_point = args[0];
+        instruction_ref x_scale;
+        if(args[1]->get_shape().elements() != 1)
+        {
+            auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+            x_scale = info.add_instruction(
+                make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
+        }
+        else
+        {
+            x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
+                                           args[1]);
+        }
         if(args.size() == 3)
         {
-            auto zero_point = args[2];
-            if(not(zero_point->get_shape().elements() == 1))
+            auto x_zero_point = args[2];
+            if(x_zero_point->get_shape().elements() != 1)
             {
-                axis = tune_axis(n_dim, axis, opd.op_name);
-                zero_point = info.add_instruction(
-                    make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
+                auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+                x_zero_point = info.add_instruction(
+                    make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
+                    x_zero_point);
+            }
+            else
+            {
+                x_zero_point = info.add_instruction(
+                    make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
             }
-            auto zero_point_int32 = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
-            auto sub_zero_point_int32 = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
-            sub_zero_point =
-                info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
+            return info.add_instruction(
+                make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
         }
-        auto dequant_input = info.add_instruction(
-            make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point);
-        auto scale = args[1];
-        if(not(scale->get_shape().elements() == 1))
-        {
-            axis = tune_axis(n_dim, axis, opd.op_name);
-            scale = info.add_instruction(
-                make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
-        }
-        return info.add_broadcastable_binary_op("mul", dequant_input, scale);
+        return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
     }
 };
...
@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
 {
     std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }
-    // y = saturate(round(x / y_scale) + zero_point)
     instruction_ref parse(const op_desc& opd,
                           const onnx_parser& /*parser*/,
                           const onnx_parser::node_info& info,
-                          std::vector<instruction_ref> args) const
+                          const std::vector<instruction_ref>& args) const
     {
-        auto quant_type = shape::uint8_type;
-        int nargs = args.size();
-        int max_quant = 255;
-        int min_quant = 0;
-        if(nargs == 3)
-            quant_type = args[2]->get_shape().type();
-        if(quant_type == shape::int8_type)
-        {
-            max_quant = 127;
-            min_quant = -128;
-        }
-        auto max_arg = info.add_literal(max_quant);
-        auto min_arg = info.add_literal(min_quant);
         int axis = 1;
         if(contains(info.attributes, "axis"))
             axis = info.attributes.at("axis").i();
         auto input_lens = args[0]->get_shape().lens();
-        int n_dim = static_cast<int>(input_lens.size());
+        auto n_dim = input_lens.size();
-        auto scale = args[1];
-        if(not(scale->get_shape().elements() == 1))
+        instruction_ref y_scale;
+        if(args[1]->get_shape().elements() != 1)
         {
-            axis = tune_axis(n_dim, axis, opd.op_name);
-            scale = info.add_instruction(
-                make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
+            auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+            y_scale = info.add_instruction(
+                make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
+        }
+        else
+        {
+            y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
+                                           args[1]);
         }
-        auto div = info.add_broadcastable_binary_op("div", args[0], scale);
-        auto div_round = info.add_instruction(make_op("round"), div);
-        auto add_zero_point = div_round;
-        if(nargs == 3)
+        if(args.size() == 3)
         {
-            auto zero_point = args[2];
-            if(not(zero_point->get_shape().elements() == 1))
+            auto y_zero_point = args[2];
+            if(y_zero_point->get_shape().elements() != 1)
             {
-                axis = tune_axis(n_dim, axis, opd.op_name);
-                zero_point = info.add_instruction(
-                    make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
+                auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+                y_zero_point = info.add_instruction(
+                    make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
+                    y_zero_point);
+            }
+            else
+            {
+                y_zero_point = info.add_instruction(
+                    make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
             }
-            zero_point = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
-            add_zero_point = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), add_zero_point);
-            add_zero_point = info.add_broadcastable_binary_op("add", add_zero_point, zero_point);
+            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
         }
-        auto s = add_zero_point->get_shape();
-        const auto& lens = s.lens();
-        std::vector<int64_t> out_lens(lens.begin(), lens.end());
-        if(min_arg->get_shape() != s)
-        {
-            min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
-                                           min_arg);
-        }
-        if(max_arg->get_shape() != s)
-        {
-            max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
-                                           max_arg);
-        }
-        auto saturated = info.add_instruction(make_op("clip"), add_zero_point, min_arg, max_arg);
-        return info.add_instruction(make_op("convert", {{"target_type", quant_type}}), saturated);
+        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
     }
 };
...
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void apply_quantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "quantizelinear");
auto x = ins->inputs()[0];
auto y_scale = ins->inputs()[1];
if(x->get_shape().type() != y_scale->get_shape().type())
{
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
int64_t max_quant = 0;
int64_t min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1];
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
m.replace_instruction(ins, make_op("mul"), x, x_scale);
}
void rewrite_quantization::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "quantizelinear")
{
apply_quantizelinear(m, ins);
}
else if(ins->name() == "dequantizelinear")
{
apply_dequantizelinear(m, ins);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
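As a sanity check on the decomposition above, the quantizelinear path is equivalent to this chain of elementwise primitives (a NumPy sketch of the div → round → add → clip → convert sequence; not MIGraphX code):

```python
import numpy as np

def rewritten_quantizelinear(x, y_scale, y_zero_point, out_dtype=np.int8):
    # Mirrors apply_quantizelinear: div -> round -> add -> clip -> convert
    x = x.astype(np.float32)                          # convert x when its type differs from the scale
    rounded = np.round(x / y_scale)                   # div + round
    shifted = rounded + np.float32(y_zero_point)      # add, zero point converted to float
    info = np.iinfo(out_dtype)
    saturated = np.clip(shifted, info.min, info.max)  # clip against the output type's limits
    return saturated.astype(out_dtype)                # final convert to the quantized type

print(rewritten_quantizelinear(np.array([-5.0, 0.3, 4.9], dtype=np.float32),
                               np.float32(0.1), np.int8(1)))  # -> [-49   4  50]
```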
@@ -17,6 +17,7 @@
 #include <migraphx/remap.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_pooling.hpp>
+#include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/schedule.hpp>
 #include <migraphx/memory_coloring.hpp>
@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
     unsupported_types.erase(shape::type_t::float_type);
     return {normalize_ops{},
+            rewrite_quantization{},
+            dead_code_elimination{},
             eliminate_data_type{unsupported_types, shape::type_t::float_type},
             dead_code_elimination{},
             decompose{},
...
@@ -20,6 +20,7 @@
 #include <migraphx/remap.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_pooling.hpp>
+#include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/schedule.hpp>
 #include <migraphx/simplify_algebra.hpp>
@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
             normalize_ops{},
             decompose{},
             dead_code_elimination{},
+            rewrite_quantization{},
+            dead_code_elimination{},
             eliminate_data_type{unsupported_types, shape::type_t::float_type},
             simplify_reshapes{},
             eliminate_identity{},
...
[Binary ONNX test data (protobuf) not rendered: dequantizelinear_test.onnx regenerated without the zero-point input, dequantizelinear_zero_point_test.onnx added, plus further regenerated .onnx fixtures whose names are not visible in this view.]
@@ -1021,6 +1021,21 @@ def deconv_stride_test():
 @onnx_test
 def dequantizelinear_test():
+    arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
+    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
+    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [5])
+
+    node = onnx.helper.make_node(
+        'DequantizeLinear',
+        inputs=['0', '1'],
+        outputs=['out'],
+    )
+
+    return ([node], [arg0, arg1], [arg_out])
+
+
+@onnx_test
+def dequantizelinear_zero_point_test():
     arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
     arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
     arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
@@ -2751,6 +2766,36 @@ def prelu_brcst_test():
 @onnx_test
 def quantizelinear_test():
+    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
+    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
+    arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5])
+
+    node = onnx.helper.make_node(
+        'QuantizeLinear',
+        inputs=['0', '1'],
+        outputs=['out'],
+    )
+
+    return ([node], [arg0, arg1], [arg_out])
+
+
+@onnx_test
+def quantizelinear_int32_test():
+    arg0 = helper.make_tensor_value_info('0', TensorProto.INT32, [5])
+    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
+    arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5])
+
+    node = onnx.helper.make_node(
+        'QuantizeLinear',
+        inputs=['0', '1'],
+        outputs=['out'],
+    )
+
+    return ([node], [arg0, arg1], [arg_out])
+
+
+@onnx_test
+def quantizelinear_zero_point_test():
     arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
     arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
     arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
...
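For reference, a triple like the ones returned above can be serialized to a .onnx file roughly as follows; the real @onnx_test decorator in gen_onnx.py may handle naming and opsets differently, so treat this as an illustrative sketch only:

```python
import onnx
from onnx import helper, TensorProto

def save_test_case(name, nodes, inputs, outputs):
    # Build a graph from the (nodes, inputs, outputs) triple and write <name>.onnx
    graph = helper.make_graph(nodes, name, inputs, outputs)
    model = helper.make_model(graph)
    onnx.save(model, name + '.onnx')

arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5])
node = helper.make_node('QuantizeLinear', inputs=['0', '1'], outputs=['out'])
save_test_case('quantizelinear_test', [node], [arg0, arg1], [arg_out])
```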
@@ -2,28 +2,35 @@
 #include <fstream>
 #include <vector>
 #include <migraphx/literal.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/instruction_ref.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/onnx.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/pad.hpp>
+#include <migraphx/op/pooling.hpp>
+#include <migraphx/op/lrn.hpp>
+#include <migraphx/op/reshape.hpp>
+#include <migraphx/op/unknown.hpp>
 #include <migraphx/serialize.hpp>
 #include "test.hpp"
 
-migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
+migraphx::program optimize_onnx(const std::string& name, bool run_passes = false)
 {
     migraphx::onnx_options options;
     options.skip_unknown_operators = true;
     auto prog = migraphx::parse_onnx(name, options);
     auto* mm = prog.get_main_module();
-    if(eliminate_deadcode)
-        migraphx::run_passes(*mm, {migraphx::dead_code_elimination{}});
+    if(run_passes)
+        migraphx::run_passes(*mm,
+                             {migraphx::rewrite_quantization{}, migraphx::dead_code_elimination{}});
 
     // remove the last identity instruction
     auto last_ins = std::prev(mm->end());
@@ -914,29 +921,42 @@ TEST_CASE(dequantizelinear_test)
     auto* mm = p.get_main_module();
     auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
-    auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
     auto l1_mbcast =
         mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
-    l2 = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        l2);
-    auto l2_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
-    l0 = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        l0);
-    auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast);
-    auto dequant = mm->add_instruction(
+    auto dequant = mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
-        sub);
+        l0);
     mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast);
 
-    auto prog = optimize_onnx("dequantizelinear_test.onnx");
+    auto prog = optimize_onnx("dequantizelinear_test.onnx", true);
 
     EXPECT(p.sort() == prog.sort());
 }
+
+TEST_CASE(dequantizelinear_zero_point_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
+    auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
+    auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
+    auto l1_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
+    auto l2_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
+    l2_mbcast = mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
+        l2_mbcast);
+    l0 = mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
+        l0);
+    auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast);
+    mm->add_instruction(migraphx::make_op("mul"), sub, l1_mbcast);
+
+    auto prog = optimize_onnx("dequantizelinear_zero_point_test.onnx", true);
+
+    EXPECT(p.sort() == prog.sort());
+}
@@ -955,19 +975,15 @@ migraphx::program make_dequantizelinear_axis_prog()
         migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
     l2_bcast = mm->add_instruction(
         migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
         l2_bcast);
     l0 = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        l0);
-    auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast);
-    auto dequant = mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
-        sub);
-    mm->add_instruction(migraphx::make_op("mul"), dequant, l1_bcast);
+        l0);
+    auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast);
+    mm->add_instruction(migraphx::make_op("mul"), sub, l1_bcast);
     return p;
 }
@@ -975,7 +991,7 @@ TEST_CASE(dequantizelinear_axis_test)
 {
     migraphx::program p = make_dequantizelinear_axis_prog();
-    auto prog = optimize_onnx("dequantizelinear_axis_test.onnx");
+    auto prog = optimize_onnx("dequantizelinear_axis_test.onnx", true);
     EXPECT(p.sort() == prog.sort());
 }
@@ -983,7 +999,7 @@ TEST_CASE(dequantizelinear_neg_axis_test)
 {
     migraphx::program p = make_dequantizelinear_axis_prog();
-    auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx");
+    auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx", true);
     EXPECT(p.sort() == prog.sort());
 }
@@ -2378,41 +2394,87 @@ TEST_CASE(prelu_brcst_test)
 TEST_CASE(quantizelinear_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
     auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
-    auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
-    auto min_val = mm->add_literal(-128);
-    auto max_val = mm->add_literal(127);
-    auto l1_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
-    auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
-    auto round = mm->add_instruction(migraphx::make_op("round"), div);
-    l2 = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        l2);
-    auto l2_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
-    round = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        round);
-    auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
-    min_val =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), min_val);
-    max_val =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), max_val);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val);
-    mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
-        clip);
-    auto prog = optimize_onnx("quantizelinear_test.onnx");
+    auto l1_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
+    auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
+    auto round = mm->add_instruction(migraphx::make_op("round"), div);
+    auto s = round->get_shape();
+    std::vector<int> min_data(s.elements(), 0);
+    std::vector<int> max_data(s.elements(), 255);
+    auto min_arg = mm->add_literal(s, min_data);
+    auto max_arg = mm->add_literal(s, max_data);
+    auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
+    mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
+        clip);
+
+    auto prog = optimize_onnx("quantizelinear_test.onnx", true);
+
+    EXPECT(p.sort() == prog.sort());
+}
+
+TEST_CASE(quantizelinear_int32_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}});
+    auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
+    auto l1_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
+    l0 = mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
+        l0);
+    auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
+    auto round = mm->add_instruction(migraphx::make_op("round"), div);
+    auto s = round->get_shape();
+    std::vector<int> min_data(s.elements(), 0);
+    std::vector<int> max_data(s.elements(), 255);
+    auto min_arg = mm->add_literal(s, min_data);
+    auto max_arg = mm->add_literal(s, max_data);
+    auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
+    mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
+        clip);
+
+    auto prog = optimize_onnx("quantizelinear_int32_test.onnx", true);
+
+    EXPECT(p.sort() == prog.sort());
+}
+
+TEST_CASE(quantizelinear_zero_point_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
+    auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
+    auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
+    auto l1_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
+    auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
+    auto round = mm->add_instruction(migraphx::make_op("round"), div);
+    auto l2_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
+    l2_mbcast = mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
+        l2_mbcast);
+    auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
+    auto s = round->get_shape();
+    std::vector<int> min_data(s.elements(), -128);
+    std::vector<int> max_data(s.elements(), 127);
+    auto min_arg = mm->add_literal(s, min_data);
+    auto max_arg = mm->add_literal(s, max_data);
+    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
+    mm->add_instruction(
+        migraphx::make_op("convert",
+                          {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
+        clip);
+
+    auto prog = optimize_onnx("quantizelinear_zero_point_test.onnx", true);
 
     EXPECT(p.sort() == prog.sort());
 }
@@ -2423,12 +2485,9 @@ migraphx::program make_quantizelinear_axis_prog()
     int axis = 2;
     auto* mm = p.get_main_module();
     auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, input_lens});
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
     auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
-    auto min_val = mm->add_literal(-128);
-    auto max_val = mm->add_literal(127);
     auto l1_bcast = mm->add_instruction(
         migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1);
@@ -2438,18 +2497,15 @@
         migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
     l2_bcast = mm->add_instruction(
         migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
+                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
         l2_bcast);
-    round = mm->add_instruction(
-        migraphx::make_op("convert",
-                          {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
-        round);
     auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
-    min_val = mm->add_instruction(
-        migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), min_val);
-    max_val = mm->add_instruction(
-        migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), max_val);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val);
+    auto s = round->get_shape();
+    std::vector<int> min_data(s.elements(), -128);
+    std::vector<int> max_data(s.elements(), 127);
+    auto min_arg = mm->add_literal(s, min_data);
+    auto max_arg = mm->add_literal(s, max_data);
+    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
     mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
@@ -2461,7 +2517,7 @@ TEST_CASE(quantizelinear_axis_test)
 {
     migraphx::program p = make_quantizelinear_axis_prog();
-    auto prog = optimize_onnx("quantizelinear_axis_test.onnx");
+    auto prog = optimize_onnx("quantizelinear_axis_test.onnx", true);
     EXPECT(p.sort() == prog.sort());
 }
@@ -2469,7 +2525,7 @@ TEST_CASE(quantizelinear_neg_axis_test)
 {
     migraphx::program p = make_quantizelinear_axis_prog();
-    auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx");
+    auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx", true);
     EXPECT(p.sort() == prog.sort());
 }
...
[Binary ONNX test data (protobuf) not rendered: quantizelinear_test.onnx regenerated without the zero-point input, quantizelinear_int32_test.onnx added, plus further regenerated .onnx fixtures whose names are not visible in this view.]