Unverified Commit 3282e01a authored by turneram's avatar turneram Committed by GitHub
Browse files

Quantize linear ops (#843)

* Add operators, refactor parsers, add rewrite passes, add tests

* Formatting

* Fix cppcheck

* Review comments

* Formatting

* Combine rewrite passes

* Formatting

* Add ref implementations

* Formatting

* Review comments

* Formatting

* Tidy warnings

* Apply review comments

* Formatting

* Fix CI error

* Formatting

* Increase code coverage

* Formatting

* Move broadcasting of scales and zero points to onnx parser

* Formatting

* Allow for x and zero_point to have different types in quantizelinear; fix zero_point default type

* Formatting

* Increase code coverage

* Formatting

* Switch certain variables to int64_t

* Formatting

* Fix overflow in implicit constant conversion

* Formatting

* Increase code coverage

* Formatting

* Remove operators.hpp from includes in tf_test.cpp

* Formatting

* Add conversion for int32 input to quantizelinear and add test case; remove operators.hpp from onnx_test.cpp includes

* Formatting

* Switch dequantizelinear math from int32 to float

* Formatting

* Remove changes to operators.hpp

* Simplify apply_quantizelinear

* Formatting

* Add verify test for int32 data

* Add rewrite_quantization back to CMakeLists
parent 4983fecd
......@@ -53,6 +53,7 @@ add_library(migraphx
remap.cpp
rewrite_batchnorm.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
schedule.cpp
serialize.cpp
......@@ -94,6 +95,7 @@ register_migraphx_ops(
cosh
cos
deconvolution
dequantizelinear
div
dot
elu
......@@ -132,6 +134,7 @@ register_migraphx_ops(
prelu
quant_convolution
quant_dot
quantizelinear
recip
reduce_max
reduce_mean
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
    std::string name() const { return "dequantizelinear"; }

    // Output is always float, preserving the input's lens and strides.
    shape compute_shape(std::vector<shape> inputs) const
    {
        return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
    }

    // y = (x - x_zero_point) * x_scale, computed elementwise.
    // args: [x, x_scale] or [x, x_scale, x_zero_point].
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto x_scale = args.at(1);
        // Default zero point: all zeros with the same type as x so visit_all
        // can dispatch both through one visitor. Size the backing buffer in
        // BYTES, not elements: x may be wider than one byte (e.g. int32), and
        // an element-sized int8 buffer would be read out of bounds.
        shape zp_shape{x.get_shape().type(), output_shape.lens()};
        std::vector<int8_t> zeros(zp_shape.bytes(), 0);
        argument x_zero_point{zp_shape, zeros.data()};
        if(args.size() == 3)
        {
            x_zero_point = args.at(2);
        }
        argument result{output_shape};
        visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
            visit_all(result, x_scale)([&](auto output, auto scales) {
                par_for(output_shape.elements(), [&](auto i) {
                    // Subtract in int64 to avoid overflow of narrow integer
                    // types, then scale in double before narrowing to float.
                    output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
                                                    static_cast<int64_t>(zero_pts[i])) *
                                scales[i];
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
    std::string name() const { return "quantizelinear"; }

    // Output type is taken from the zero-point input when present,
    // otherwise it defaults to uint8 (matching the ONNX spec).
    shape compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.size() == 3)
        {
            return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
        }
        return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
    }

    // y = saturate(round(x / y_scale) + y_zero_point)
    // args: [x, y_scale] or [x, y_scale, y_zero_point].
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto y_scale = args.at(1);
        // Default zero point: all zeros with the output type. Size the backing
        // buffer in BYTES so a wider-than-byte output type cannot read past
        // the end of the int8 storage.
        std::vector<int8_t> zeros(output_shape.bytes(), 0);
        argument y_zero_point{output_shape, zeros.data()};
        if(args.size() == 3)
        {
            y_zero_point = args.at(2);
        }
        argument result{output_shape};
        visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
            x.visit([&](auto input) {
                y_scale.visit([&](auto scales) {
                    // Saturation bounds come from the quantized output type.
                    using quant_type = typename decltype(output)::value_type;
                    auto min_value   = std::numeric_limits<quant_type>::min();
                    auto max_value   = std::numeric_limits<quant_type>::max();
                    par_for(output_shape.elements(), [&](auto i) {
                        // Accumulate in int64 so round()+zero_point cannot
                        // overflow before clamping.
                        int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
                                            static_cast<int64_t>(zero_pts[i]);
                        output[i] = std::max(static_cast<int64_t>(min_value),
                                             std::min(static_cast<int64_t>(max_value), quantized));
                    });
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Rewrite quantization ops to equivalent operators
*/
struct rewrite_quantization
{
    // Pass name reported to the pass manager.
    std::string name() const { return "rewrite_quantization"; }
    // Rewrite every quantizelinear/dequantizelinear instruction in m into
    // equivalent primitive operators (implementation in rewrite_quantization.cpp).
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
int axis = 1;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size());
auto n_dim = input_lens.size();
auto sub_zero_point = args[0];
instruction_ref x_scale;
if(args[1]->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
args[1]);
}
if(args.size() == 3)
{
auto zero_point = args[2];
if(not(zero_point->get_shape().elements() == 1))
auto x_zero_point = args[2];
if(x_zero_point->get_shape().elements() != 1)
{
axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
x_zero_point);
}
auto zero_point_int32 = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
auto sub_zero_point_int32 = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
sub_zero_point =
info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
}
auto dequant_input = info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point);
auto scale = args[1];
if(not(scale->get_shape().elements() == 1))
{
axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
return info.add_instruction(
make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
}
return info.add_broadcastable_binary_op("mul", dequant_input, scale);
return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
}
};
......
......@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }
// y = saturate(round(x / y_scale) + zero_point)
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
auto quant_type = shape::uint8_type;
int nargs = args.size();
int max_quant = 255;
int min_quant = 0;
if(nargs == 3)
quant_type = args[2]->get_shape().type();
if(quant_type == shape::int8_type)
{
max_quant = 127;
min_quant = -128;
}
auto max_arg = info.add_literal(max_quant);
auto min_arg = info.add_literal(min_quant);
int axis = 1;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size());
auto n_dim = input_lens.size();
auto scale = args[1];
if(not(scale->get_shape().elements() == 1))
instruction_ref y_scale;
if(args[1]->get_shape().elements() != 1)
{
axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
}
auto div = info.add_broadcastable_binary_op("div", args[0], scale);
auto div_round = info.add_instruction(make_op("round"), div);
auto add_zero_point = div_round;
if(nargs == 3)
else
{
auto zero_point = args[2];
if(not(zero_point->get_shape().elements() == 1))
{
axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
}
zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
add_zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), add_zero_point);
add_zero_point = info.add_broadcastable_binary_op("add", add_zero_point, zero_point);
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
args[1]);
}
auto s = add_zero_point->get_shape();
const auto& lens = s.lens();
std::vector<int64_t> out_lens(lens.begin(), lens.end());
if(min_arg->get_shape() != s)
if(args.size() == 3)
{
auto y_zero_point = args[2];
if(y_zero_point->get_shape().elements() != 1)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
min_arg);
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
y_zero_point);
}
if(max_arg->get_shape() != s)
else
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
max_arg);
y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
}
auto saturated = info.add_instruction(make_op("clip"), add_zero_point, min_arg, max_arg);
return info.add_instruction(make_op("convert", {{"target_type", quant_type}}), saturated);
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
}
};
......
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Lower a quantizelinear instruction into primitive operators:
//   y = saturate(round(x / y_scale) + y_zero_point)
// expressed as convert -> div -> round -> (add) -> clip -> convert.
void apply_quantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "quantizelinear");
    auto x       = ins->inputs()[0];
    auto y_scale = ins->inputs()[1];
    // Divide in float: integer inputs (e.g. int32) must not truncate.
    if(x->get_shape().type() != y_scale->get_shape().type())
    {
        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
    }
    auto div            = m.insert_instruction(ins, make_op("div"), x, y_scale);
    auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
    if(ins->inputs().size() == 3)
    {
        // Zero point is an integer type; convert to float to add it to the
        // rounded quotient.
        auto zero_point = m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
        add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
    }
    // Saturation bounds come from the numeric limits of the output type.
    int64_t max_quant = 0;
    int64_t min_quant = 0;
    ins->get_shape().visit_type([&](auto qt) {
        max_quant = qt.max();
        min_quant = qt.min();
    });
    auto s = add_zero_point->get_shape();
    // Keep the limits in int64_t to avoid a narrowing conversion; the literal
    // constructor converts the data to the shape's type.
    std::vector<int64_t> min_data(s.elements(), min_quant);
    std::vector<int64_t> max_data(s.elements(), max_quant);
    auto min_arg  = m.add_literal(literal(s, min_data));
    auto max_arg  = m.add_literal(literal(s, max_data));
    auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
    m.replace_instruction(
        ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
// Lower a dequantizelinear instruction into primitive operators:
//   y = (x - x_zero_point) * x_scale
// expressed as convert -> (sub) -> mul, all arithmetic in float.
void apply_dequantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "dequantizelinear");
    const auto& inputs = ins->inputs();
    // Helper: insert a convert-to-float of arg just before ins.
    auto to_float = [&](instruction_ref arg) {
        return m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), arg);
    };
    auto dequant = to_float(inputs[0]);
    if(inputs.size() == 3)
        dequant = m.insert_instruction(ins, make_op("sub"), dequant, to_float(inputs[2]));
    m.replace_instruction(ins, make_op("mul"), dequant, inputs[1]);
}
// Walk every instruction in the module and lower the quantization ops.
void rewrite_quantization::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const auto& op = ins->name();
        if(op == "quantizelinear")
            apply_quantizelinear(m, ins);
        else if(op == "dequantizelinear")
            apply_dequantizelinear(m, ins);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -17,6 +17,7 @@
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp>
......@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{},
decompose{},
......
......@@ -20,6 +20,7 @@
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp>
......@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
normalize_ops{},
decompose{},
dead_code_elimination{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
......
......@@ -23,4 +23,4 @@



B
\ No newline at end of file
B
\ No newline at end of file
......@@ -23,4 +23,4 @@



B
\ No newline at end of file
B
\ No newline at end of file
dequantizelinear_test:
dequantizelinear_test:k

0
1
2out"DequantizeLineardequantizelinear_testZ
1out"DequantizeLineardequantizelinear_testZ
0

......@@ -10,12 +9,8 @@
1

Z
2

b
out

B
\ No newline at end of file
B
\ No newline at end of file
 dequantizelinear_zero_point_test:
0
1
2out"DequantizeLinear dequantizelinear_zero_point_testZ
0

Z
1

Z
2

b
out

B
\ No newline at end of file
......@@ -1021,6 +1021,21 @@ def deconv_stride_test():
@onnx_test
def dequantizelinear_test():
    # DequantizeLinear without a zero point: int8 data, scalar float scale.
    data = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
    scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
    result = helper.make_tensor_value_info('out', TensorProto.FLOAT, [5])

    node = onnx.helper.make_node('DequantizeLinear',
                                 inputs=['0', '1'],
                                 outputs=['out'])

    return ([node], [data, scale], [result])
@onnx_test
def dequantizelinear_zero_point_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
......@@ -2751,6 +2766,36 @@ def prelu_brcst_test():
@onnx_test
def quantizelinear_test():
    # QuantizeLinear without a zero point: float data, scalar float scale.
    data = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
    scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
    result = helper.make_tensor_value_info('out', TensorProto.INT8, [5])

    node = onnx.helper.make_node('QuantizeLinear',
                                 inputs=['0', '1'],
                                 outputs=['out'])

    return ([node], [data, scale], [result])
@onnx_test
def quantizelinear_int32_test():
    # QuantizeLinear with int32 input data (exercises the convert-to-float path).
    data = helper.make_tensor_value_info('0', TensorProto.INT32, [5])
    scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
    result = helper.make_tensor_value_info('out', TensorProto.INT8, [5])

    node = onnx.helper.make_node('QuantizeLinear',
                                 inputs=['0', '1'],
                                 outputs=['out'])

    return ([node], [data, scale], [result])
@onnx_test
def quantizelinear_zero_point_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
......
......@@ -2,28 +2,35 @@
#include <fstream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
migraphx::program optimize_onnx(const std::string& name, bool run_passes = false)
{
migraphx::onnx_options options;
options.skip_unknown_operators = true;
auto prog = migraphx::parse_onnx(name, options);
auto* mm = prog.get_main_module();
if(eliminate_deadcode)
migraphx::run_passes(*mm, {migraphx::dead_code_elimination{}});
if(run_passes)
migraphx::run_passes(*mm,
{migraphx::rewrite_quantization{}, migraphx::dead_code_elimination{}});
// remove the last identity instruction
auto last_ins = std::prev(mm->end());
......@@ -914,29 +921,42 @@ TEST_CASE(dequantizelinear_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
l2 = mm->add_instruction(
auto dequant = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
l2);
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast);
auto prog = optimize_onnx("dequantizelinear_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(dequantizelinear_zero_point_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
l0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast);
auto dequant = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
sub);
mm->add_instruction(migraphx::make_op("mul"), sub, l1_mbcast);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast);
auto prog = optimize_onnx("dequantizelinear_test.onnx");
auto prog = optimize_onnx("dequantizelinear_zero_point_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......@@ -955,19 +975,15 @@ migraphx::program make_dequantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
l2_bcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast);
l0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast);
auto dequant = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
sub);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_bcast);
mm->add_instruction(migraphx::make_op("mul"), sub, l1_bcast);
return p;
}
......@@ -975,7 +991,7 @@ TEST_CASE(dequantizelinear_axis_test)
{
migraphx::program p = make_dequantizelinear_axis_prog();
auto prog = optimize_onnx("dequantizelinear_axis_test.onnx");
auto prog = optimize_onnx("dequantizelinear_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......@@ -983,7 +999,7 @@ TEST_CASE(dequantizelinear_neg_axis_test)
{
migraphx::program p = make_dequantizelinear_axis_prog();
auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx");
auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......@@ -2381,38 +2397,84 @@ TEST_CASE(quantizelinear_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto min_val = mm->add_literal(-128);
auto max_val = mm->add_literal(127);
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
clip);
auto prog = optimize_onnx("quantizelinear_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_int32_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
l0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
l2 = mm->add_instruction(
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
l2);
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
clip);
auto prog = optimize_onnx("quantizelinear_int32_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_zero_point_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
round = mm->add_instruction(
l2_mbcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
round);
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), max_val);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
std::vector<int> max_data(s.elements(), 127);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
clip);
auto prog = optimize_onnx("quantizelinear_test.onnx");
auto prog = optimize_onnx("quantizelinear_zero_point_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......@@ -2426,9 +2488,6 @@ migraphx::program make_quantizelinear_axis_prog()
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, input_lens});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto min_val = mm->add_literal(-128);
auto max_val = mm->add_literal(127);
auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1);
......@@ -2438,18 +2497,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
l2_bcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast);
round = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
round);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), min_val);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), max_val);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
std::vector<int> max_data(s.elements(), 127);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
......@@ -2461,7 +2517,7 @@ TEST_CASE(quantizelinear_axis_test)
{
migraphx::program p = make_quantizelinear_axis_prog();
auto prog = optimize_onnx("quantizelinear_axis_test.onnx");
auto prog = optimize_onnx("quantizelinear_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......@@ -2469,7 +2525,7 @@ TEST_CASE(quantizelinear_neg_axis_test)
{
migraphx::program p = make_quantizelinear_axis_prog();
auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx");
auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
......
......@@ -23,4 +23,4 @@



B
\ No newline at end of file
B
\ No newline at end of file
quantizelinear_int32_test:m

0
1out"QuantizeLinearquantizelinear_int32_testZ
0

Z
1

b
out

B
\ No newline at end of file
......@@ -23,4 +23,4 @@



B
\ No newline at end of file
B
\ No newline at end of file
quantizelinear_test:{

quantizelinear_test:g

0
1
2out"QuantizeLinearquantizelinear_testZ
1out"QuantizeLinearquantizelinear_testZ
0

......@@ -10,12 +9,8 @@
1

Z
2

b
out

B
\ No newline at end of file
B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment