Unverified Commit ad4a68f0 authored by Cagri Eryilmaz's avatar Cagri Eryilmaz Committed by GitHub
Browse files

Merge branch 'develop' into unet

parents 7555a27a eacf042e
...@@ -53,6 +53,7 @@ add_library(migraphx ...@@ -53,6 +53,7 @@ add_library(migraphx
remap.cpp remap.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
schedule.cpp schedule.cpp
serialize.cpp serialize.cpp
...@@ -94,6 +95,7 @@ register_migraphx_ops( ...@@ -94,6 +95,7 @@ register_migraphx_ops(
cosh cosh
cos cos
deconvolution deconvolution
dequantizelinear
div div
dot dot
elu elu
...@@ -132,6 +134,7 @@ register_migraphx_ops( ...@@ -132,6 +134,7 @@ register_migraphx_ops(
prelu prelu
quant_convolution quant_convolution
quant_dot quant_dot
quantizelinear
recip recip
reduce_max reduce_max
reduce_mean reduce_mean
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * Reference operator computing, element by element,
 *   output = (x - x_zero_point) * x_scale
 * with the result stored as float.
 *
 * Inputs (in order): x, x_scale and, optionally, x_zero_point. When the
 * zero point is omitted it is treated as all zeros.
 */
struct dequantizelinear
{
    std::string name() const { return "dequantizelinear"; }

    /// Output is float with the same lens and strides as the first input.
    /// Throws std::out_of_range if no inputs are supplied.
    shape compute_shape(std::vector<shape> inputs) const
    {
        const auto& data = inputs.at(0);
        return {shape::float_type, data.lens(), data.strides()};
    }

    /// Evaluates the operator on host memory.
    /// `args` holds x, x_scale and optionally x_zero_point.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto x_scale = args.at(1);

        // Default zero point: an all-zero tensor typed like x so that
        // visit_all(x, x_zero_point) sees matching types. The backing buffer
        // is sized in bytes for that type; one byte per element would be too
        // small whenever x is wider than 8 bits and would be read out of bounds.
        const shape zero_shape{x.get_shape().type(), output_shape.lens()};
        std::vector<char> zeros(zero_shape.bytes(), 0);
        argument x_zero_point{zero_shape, zeros.data()};
        if(args.size() == 3)
        {
            x_zero_point = args.at(2);
        }

        argument result{output_shape};
        visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
            visit_all(result, x_scale)([&](auto output, auto scales) {
                par_for(output_shape.elements(), [&](auto i) {
                    // Subtract in 64-bit integers so narrow quantized types
                    // cannot wrap before the scale is applied.
                    output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
                                                    static_cast<int64_t>(zero_pts[i])) *
                                scales[i];
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * Reference operator computing, element by element,
 *   output = saturate(round(x / y_scale) + y_zero_point)
 * where saturation clamps to the numeric range of the output type.
 *
 * Inputs (in order): x, y_scale and, optionally, y_zero_point.
 */
struct quantizelinear
{
    std::string name() const { return "quantizelinear"; }

    /// The output takes the zero point's type when three inputs are given and
    /// uint8 otherwise; lens and strides are copied from the first input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        const auto out_type = (inputs.size() == 3) ? inputs[2].type() : shape::uint8_type;
        return {out_type, inputs[0].lens(), inputs[0].strides()};
    }

    /// Evaluates the operator on host memory.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto y_scale = args.at(1);

        // All-zero stand-in used when no zero point input is present.
        std::vector<int8_t> zero_buffer(output_shape.elements(), 0);
        argument y_zero_point =
            (args.size() == 3) ? args.at(2) : argument{output_shape, zero_buffer.data()};

        argument result{output_shape};
        visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
            // Saturation bounds come from the output element type.
            using quant_type   = typename decltype(output)::value_type;
            const auto lowest  = static_cast<int64_t>(std::numeric_limits<quant_type>::min());
            const auto highest = static_cast<int64_t>(std::numeric_limits<quant_type>::max());
            x.visit([&](auto input) {
                y_scale.visit([&](auto scales) {
                    par_for(output_shape.elements(), [&](auto i) {
                        const auto scaled     = static_cast<int64_t>(std::round(input[i] / scales[i]));
                        const int64_t shifted = scaled + static_cast<int64_t>(zero_pts[i]);
                        output[i]             = std::max(lowest, std::min(highest, shifted));
                    });
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;

/**
 * Pass that rewrites quantization ops to equivalent operators.
 */
struct rewrite_quantization
{
    /// Identifier of this pass.
    std::string name() const
    {
        return "rewrite_quantization";
    }

    /// Runs the rewrite over module `m`; the definition lives out of line.
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const const std::vector<instruction_ref>& args) const
{ {
int axis = 1; int axis = 1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i(); axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size()); auto n_dim = input_lens.size();
auto sub_zero_point = args[0]; instruction_ref x_scale;
if(args[1]->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
auto zero_point = args[2]; auto x_zero_point = args[2];
if(not(zero_point->get_shape().elements() == 1)) if(x_zero_point->get_shape().elements() != 1)
{ {
axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point); make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
x_zero_point);
} }
else
auto zero_point_int32 = info.add_instruction( {
make_op("convert", {{"target_type", shape::int32_type}}), zero_point); x_zero_point = info.add_instruction(
auto sub_zero_point_int32 = info.add_instruction( make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
sub_zero_point =
info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
} }
auto dequant_input = info.add_instruction( return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point); make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
auto scale = args[1];
if(not(scale->get_shape().elements() == 1))
{
axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
} }
return info.add_broadcastable_binary_op("mul", dequant_input, scale);
return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
} }
}; };
......
...@@ -42,13 +42,23 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -42,13 +42,23 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements // swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0]) auto l1 = args[0];
: args[0];
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
auto alpha_l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
l1 = info.add_instruction(make_op("convert", {{"target_type", l1->get_shape().type()}}),
alpha_l1);
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1]; : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.f && args[2]->get_shape().elements() > 0) if(beta != 0.0f && args[2]->get_shape().elements() > 0)
{ {
auto out_lens = l1->get_shape().lens(); auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back(); out_lens.back() = l2->get_shape().lens().back();
...@@ -59,12 +69,17 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -59,12 +69,17 @@ struct parse_gemm : op_parser<parse_gemm>
l3 = info.add_instruction( l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]); make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
} }
auto beta_literal = info.add_literal(beta);
auto beta_broadcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), beta_literal);
l3 = info.add_instruction(make_op("mul"), l3, beta_broadcast);
return info.add_instruction( return info.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, l3);
} }
} }
return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2); return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
} }
}; };
......
...@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; } std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }
// y = saturate(round(x / y_scale) + zero_point)
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const const std::vector<instruction_ref>& args) const
{ {
auto quant_type = shape::uint8_type;
int nargs = args.size();
int max_quant = 255;
int min_quant = 0;
if(nargs == 3)
quant_type = args[2]->get_shape().type();
if(quant_type == shape::int8_type)
{
max_quant = 127;
min_quant = -128;
}
auto max_arg = info.add_literal(max_quant);
auto min_arg = info.add_literal(min_quant);
int axis = 1; int axis = 1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i(); axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size()); auto n_dim = input_lens.size();
auto scale = args[1]; instruction_ref y_scale;
if(not(scale->get_shape().elements() == 1)) if(args[1]->get_shape().elements() != 1)
{ {
axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale); make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
} }
else
auto div = info.add_broadcastable_binary_op("div", args[0], scale);
auto div_round = info.add_instruction(make_op("round"), div);
auto add_zero_point = div_round;
if(nargs == 3)
{ {
auto zero_point = args[2]; y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
if(not(zero_point->get_shape().elements() == 1)) args[1]);
{
axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
}
zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
add_zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), add_zero_point);
add_zero_point = info.add_broadcastable_binary_op("add", add_zero_point, zero_point);
} }
auto s = add_zero_point->get_shape(); if(args.size() == 3)
const auto& lens = s.lens(); {
std::vector<int64_t> out_lens(lens.begin(), lens.end()); auto y_zero_point = args[2];
if(min_arg->get_shape() != s) if(y_zero_point->get_shape().elements() != 1)
{ {
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
min_arg); y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
y_zero_point);
} }
if(max_arg->get_shape() != s) else
{ {
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), y_zero_point = info.add_instruction(
max_arg); make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
} }
auto saturated = info.add_instruction(make_op("clip"), add_zero_point, min_arg, max_arg); return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
return info.add_instruction(make_op("convert", {{"target_type", quant_type}}), saturated);
} }
}; };
......
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Replaces a "quantizelinear" instruction with the chain
//   convert(clip(round(x / y_scale) [+ float(zero_point)], min, max))
// where min/max are the limits of the instruction's output type.
void apply_quantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "quantizelinear");

    // Inserts a convert-to-float of `v` just before `ins`.
    auto as_float = [&](instruction_ref v) {
        return m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), v);
    };

    const auto operands = ins->inputs();
    auto data           = operands[0];
    auto scale          = operands[1];

    // The division needs both operands in the same type.
    if(data->get_shape().type() != scale->get_shape().type())
    {
        data = as_float(data);
    }

    auto quotient = m.insert_instruction(ins, make_op("div"), data, scale);
    auto shifted  = m.insert_instruction(ins, make_op("round"), quotient);
    if(operands.size() == 3)
    {
        auto zp_float = as_float(operands[2]);
        shifted       = m.insert_instruction(ins, make_op("add"), shifted, zp_float);
    }

    // Saturation limits are taken from the quantized output type.
    int64_t upper = 0;
    int64_t lower = 0;
    ins->get_shape().visit_type([&](auto qt) {
        upper = qt.max();
        lower = qt.min();
    });

    // Clip bounds are materialised as literals shaped like the value being clipped.
    auto clip_shape = shifted->get_shape();
    std::vector<int> lower_values(clip_shape.elements(), lower);
    std::vector<int> upper_values(clip_shape.elements(), upper);
    auto lower_lit = m.add_literal(literal(clip_shape, lower_values));
    auto upper_lit = m.add_literal(literal(clip_shape, upper_values));

    auto clipped = m.insert_instruction(ins, make_op("clip"), shifted, lower_lit, upper_lit);
    m.replace_instruction(
        ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), clipped);
}
// Replaces a "dequantizelinear" instruction with
//   mul(float(x) [- float(zero_point)], x_scale)
void apply_dequantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "dequantizelinear");

    // Inserts a convert-to-float of `v` just before `ins`.
    auto as_float = [&](instruction_ref v) {
        return m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), v);
    };

    auto value = as_float(ins->inputs()[0]);
    auto scale = ins->inputs()[1];

    // The zero point is optional; subtract it only when it was supplied.
    const bool has_zero_point = ins->inputs().size() == 3;
    if(has_zero_point)
    {
        auto zp_float = as_float(ins->inputs()[2]);
        value         = m.insert_instruction(ins, make_op("sub"), value, zp_float);
    }

    m.replace_instruction(ins, make_op("mul"), value, scale);
}
// Walks every instruction of the module and lowers quantizelinear /
// dequantizelinear instructions through their dedicated helpers.
void rewrite_quantization::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const auto op_name = ins->name();
        if(op_name == "quantizelinear")
        {
            apply_quantizelinear(m, ins);
            continue;
        }
        if(op_name == "dequantizelinear")
        {
            apply_dequantizelinear(m, ins);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <migraphx/remap.hpp> #include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
...@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{}, return {normalize_ops{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{}, dead_code_elimination{},
decompose{}, decompose{},
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <migraphx/remap.hpp> #include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
...@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
normalize_ops{}, normalize_ops{},
decompose{}, decompose{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -101,7 +102,9 @@ inline Stream& operator<<(Stream& s, std::nullptr_t) ...@@ -101,7 +102,9 @@ inline Stream& operator<<(Stream& s, std::nullptr_t)
return s; return s;
} }
template <class Stream, class Range> template <class Stream,
class Range,
class = typename std::enable_if<not std::is_convertible<Range, std::string>{}>::type>
inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end())) inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end()))
{ {
s << "{ "; s << "{ ";
......
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
 
 
 
B B
\ No newline at end of file \ No newline at end of file
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
 
 
 
B B
\ No newline at end of file \ No newline at end of file
dequantizelinear_test: dequantizelinear_test:k

0 0
1 1out"DequantizeLineardequantizelinear_testZ
2out"DequantizeLineardequantizelinear_testZ
0 0
 
...@@ -10,12 +9,8 @@ ...@@ -10,12 +9,8 @@
1 1
 
Z
2

b b
out out
 
B B
\ No newline at end of file \ No newline at end of file
 dequantizelinear_zero_point_test:
0
1
2out"DequantizeLinear dequantizelinear_zero_point_testZ
0

Z
1

Z
2

b
out

B
\ No newline at end of file
...@@ -1021,6 +1021,21 @@ def deconv_stride_test(): ...@@ -1021,6 +1021,21 @@ def deconv_stride_test():
@onnx_test @onnx_test
def dequantizelinear_test(): def dequantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [5])
node = onnx.helper.make_node(
'DequantizeLinear',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def dequantizelinear_zero_point_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5]) arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
...@@ -2751,6 +2766,36 @@ def prelu_brcst_test(): ...@@ -2751,6 +2766,36 @@ def prelu_brcst_test():
@onnx_test @onnx_test
def quantizelinear_test(): def quantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5])
node = onnx.helper.make_node(
'QuantizeLinear',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def quantizelinear_int32_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.INT32, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5])
node = onnx.helper.make_node(
'QuantizeLinear',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def quantizelinear_zero_point_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1])
arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
......
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <vector> #include <vector>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false) migraphx::program optimize_onnx(const std::string& name, bool run_passes = false)
{ {
migraphx::onnx_options options; migraphx::onnx_options options;
options.skip_unknown_operators = true; options.skip_unknown_operators = true;
auto prog = migraphx::parse_onnx(name, options); auto prog = migraphx::parse_onnx(name, options);
auto* mm = prog.get_main_module(); auto* mm = prog.get_main_module();
if(eliminate_deadcode) if(run_passes)
migraphx::run_passes(*mm, {migraphx::dead_code_elimination{}}); migraphx::run_passes(*mm,
{migraphx::rewrite_quantization{}, migraphx::dead_code_elimination{}});
// remove the last identity instruction // remove the last identity instruction
auto last_ins = std::prev(mm->end()); auto last_ins = std::prev(mm->end());
...@@ -914,29 +922,42 @@ TEST_CASE(dequantizelinear_test) ...@@ -914,29 +922,42 @@ TEST_CASE(dequantizelinear_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
l2 = mm->add_instruction( auto dequant = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2); l0);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast);
auto prog = optimize_onnx("dequantizelinear_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(dequantizelinear_zero_point_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto l2_mbcast = auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
l0 = mm->add_instruction( l0 = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0); l0);
auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast); auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast);
auto dequant = mm->add_instruction( mm->add_instruction(migraphx::make_op("mul"), sub, l1_mbcast);
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
sub);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast);
auto prog = optimize_onnx("dequantizelinear_test.onnx"); auto prog = optimize_onnx("dequantizelinear_zero_point_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
...@@ -955,19 +976,15 @@ migraphx::program make_dequantizelinear_axis_prog() ...@@ -955,19 +976,15 @@ migraphx::program make_dequantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
l2_bcast = mm->add_instruction( l2_bcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast); l2_bcast);
l0 = mm->add_instruction( l0 = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0); l0);
auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast); auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast);
auto dequant = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
sub);
mm->add_instruction(migraphx::make_op("mul"), dequant, l1_bcast); mm->add_instruction(migraphx::make_op("mul"), sub, l1_bcast);
return p; return p;
} }
...@@ -975,7 +992,7 @@ TEST_CASE(dequantizelinear_axis_test) ...@@ -975,7 +992,7 @@ TEST_CASE(dequantizelinear_axis_test)
{ {
migraphx::program p = make_dequantizelinear_axis_prog(); migraphx::program p = make_dequantizelinear_axis_prog();
auto prog = optimize_onnx("dequantizelinear_axis_test.onnx"); auto prog = optimize_onnx("dequantizelinear_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
...@@ -983,7 +1000,7 @@ TEST_CASE(dequantizelinear_neg_axis_test) ...@@ -983,7 +1000,7 @@ TEST_CASE(dequantizelinear_neg_axis_test)
{ {
migraphx::program p = make_dequantizelinear_axis_prog(); migraphx::program p = make_dequantizelinear_axis_prog();
auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx"); auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
...@@ -1256,15 +1273,24 @@ TEST_CASE(gemm_test) ...@@ -1256,15 +1273,24 @@ TEST_CASE(gemm_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type}); auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto bl2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
auto alpha = 2.f; auto alpha = 2.f;
auto beta = 2.0f; auto beta = 2.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, t1, bl2); auto a_l = mm->add_literal(alpha);
auto prog = optimize_onnx("gemm_test.onnx"); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, t1, l2_bb);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1275,10 +1301,21 @@ TEST_CASE(gemm_ex_test) ...@@ -1275,10 +1301,21 @@ TEST_CASE(gemm_ex_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, l1, l2); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
auto prog = optimize_onnx("gemm_ex_test.onnx"); auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1291,13 +1328,25 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1291,13 +1328,25 @@ TEST_CASE(gemm_ex_brcst_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}}); auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7}; std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, l1, t2); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx"); auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -2381,38 +2430,84 @@ TEST_CASE(quantizelinear_test) ...@@ -2381,38 +2430,84 @@ TEST_CASE(quantizelinear_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto min_val = mm->add_literal(-128);
auto max_val = mm->add_literal(127);
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
clip);
auto prog = optimize_onnx("quantizelinear_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_int32_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
l0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
l2 = mm->add_instruction( auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
l2); clip);
auto prog = optimize_onnx("quantizelinear_int32_test.onnx", true);
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_zero_point_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_mbcast = auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
round = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
round); l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast); auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
min_val = auto s = round->get_shape();
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), min_val); std::vector<int> min_data(s.elements(), -128);
max_val = std::vector<int> max_data(s.elements(), 127);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), max_val); auto min_arg = mm->add_literal(s, min_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val); auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
clip); clip);
auto prog = optimize_onnx("quantizelinear_test.onnx"); auto prog = optimize_onnx("quantizelinear_zero_point_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
...@@ -2426,9 +2521,6 @@ migraphx::program make_quantizelinear_axis_prog() ...@@ -2426,9 +2521,6 @@ migraphx::program make_quantizelinear_axis_prog()
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, input_lens}); auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, input_lens});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}}); auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto min_val = mm->add_literal(-128);
auto max_val = mm->add_literal(127);
auto l1_bcast = mm->add_instruction( auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1);
...@@ -2438,18 +2530,15 @@ migraphx::program make_quantizelinear_axis_prog() ...@@ -2438,18 +2530,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2);
l2_bcast = mm->add_instruction( l2_bcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast); l2_bcast);
round = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
round);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast); auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
min_val = mm->add_instruction( auto s = round->get_shape();
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), min_val); std::vector<int> min_data(s.elements(), -128);
max_val = mm->add_instruction( std::vector<int> max_data(s.elements(), 127);
migraphx::make_op("multibroadcast", {{"output_lens", {1, 1, 5, 1}}}), max_val); auto min_arg = mm->add_literal(s, min_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_val, max_val); auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...@@ -2461,7 +2550,7 @@ TEST_CASE(quantizelinear_axis_test) ...@@ -2461,7 +2550,7 @@ TEST_CASE(quantizelinear_axis_test)
{ {
migraphx::program p = make_quantizelinear_axis_prog(); migraphx::program p = make_quantizelinear_axis_prog();
auto prog = optimize_onnx("quantizelinear_axis_test.onnx"); auto prog = optimize_onnx("quantizelinear_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
...@@ -2469,7 +2558,7 @@ TEST_CASE(quantizelinear_neg_axis_test) ...@@ -2469,7 +2558,7 @@ TEST_CASE(quantizelinear_neg_axis_test)
{ {
migraphx::program p = make_quantizelinear_axis_prog(); migraphx::program p = make_quantizelinear_axis_prog();
auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx"); auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx", true);
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
......
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
 
 
 
B B
\ No newline at end of file \ No newline at end of file
quantizelinear_int32_test:m

0
1out"QuantizeLinearquantizelinear_int32_testZ
0

Z
1

b
out

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment