Commit d64ab071 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

temp code backup.

parent 852a517a
...@@ -20,6 +20,8 @@ namespace op { ...@@ -20,6 +20,8 @@ namespace op {
struct convert : unary<convert> struct convert : unary<convert>
{ {
shape::type_t target_type = shape::half_type; shape::type_t target_type = shape::half_type;
float scale = 1.0f;
float shift = 0.0f;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -35,10 +37,11 @@ struct convert : unary<convert> ...@@ -35,10 +37,11 @@ struct convert : unary<convert>
auto apply() const auto apply() const
{ {
return [](auto x) { return x; }; return [&](auto x) { return scale * x + shift; };
} }
convert(shape::type_t t) : target_type{t} {} convert(shape::type_t t) : target_type{t} {}
convert(shape::type_t t, float sle, float sft) : target_type{t}, scale {sle}, shift{sft} {}
convert() {} convert() {}
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP #define MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#include <array> #include <string>
#include <migraphx/op/unary.hpp> #include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct convert : unary<convert>
{
    // Element type the input tensor is converted to; defaults to fp16.
    shape::type_t target_type = shape::half_type;
    // Affine (de)quantization parameters: each element becomes
    // scale * x + shift before the conversion to target_type.
    float scale = 1.0f;
    float shift = 0.0f;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
            f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
    }

    // Output shape keeps the input's lengths and strides; only the
    // element type changes.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
    }

    auto apply() const
    {
        // Capture the parameters by value: [&] would capture `this`, and the
        // returned lambda could then dangle if it outlives this operator.
        const auto a = scale;
        const auto b = shift;
        return [a, b](auto x) { return a * x + b; };
    }

    convert(shape::type_t t) : target_type{t} {}
    convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
    convert() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -3,6 +3,11 @@ ...@@ -3,6 +3,11 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <utility> #include <utility>
...@@ -26,7 +31,7 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -26,7 +31,7 @@ instruction_ref insert_quant_ins(program& prog,
ins->get_shape().type() == shape::double_type || ins->get_shape().type() == shape::double_type ||
ins->get_shape().type() == shape::int32_type); ins->get_shape().type() == shape::int32_type);
instruction_ref quant_ins{}; instruction_ref quant_ins{};
quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins); quant_ins = prog.insert_instruction(std::next(ins), op::convert{type, scale, shift}, ins);
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
return quant_ins; return quant_ins;
...@@ -112,7 +117,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -112,7 +117,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// For now, we only support the int8 quantization of gemm and convolution // For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) { if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return std::find(op_names.begin(), op_names.end(), name); return (std::find(op_names.begin(), op_names.end(), name) != op_names.end());
})) }))
{ {
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
...@@ -209,7 +214,36 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -209,7 +214,36 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// When converting from other types to int8_type, there are parameters // When converting from other types to int8_type, there are parameters
// used as scale and shift(.0f), which will generate results different from // used as scale and shift(.0f), which will generate results different from
// the original results. To adjust the output to be "correct (approximately // the original results. To adjust the output to be "correct (approximately
// equal)", we need additional calculation for that. // equal)", we need additional calculation for the adjustment
if (ins->name() == "dot")
{
    // Fold the int8 de-quantization scales into the gemm coefficients so a
    // quant_dot can replace the original dot directly.
    auto dot_op = any_cast<op::dot>(ins->get_operator());
    // alpha is divided by the product of the two input scales; +0.5f rounds
    // to nearest for non-negative values before the int32 truncation.
    int32_t quant_alpha = static_cast<int32_t>(dot_op.alpha / (int8_param[0].first * int8_param[1].first) + 0.5f);
    int32_t quant_beta = static_cast<int32_t>(dot_op.beta + 0.5f);
    prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
}
else if (ins->name() == "convolution")
{
    // Current MIOpen convolution does not support alpha and beta,
    // so we need a separate multiply to adjust the output
    auto conv_op = any_cast<op::convolution>(ins->get_operator());
    auto padding = conv_op.padding;
    auto stride = conv_op.stride;
    auto dilation = conv_op.dilation;
    auto padding_mode = conv_op.padding_mode;
    auto group = conv_op.group;
    // NOTE(review): precedence — this evaluates (1.0 / scale0) * scale1,
    // while the dot branch above divides by (scale0 * scale1). Presumably
    // this should be 1.0 / (int8_param[0].first * int8_param[1].first);
    // confirm the intended formula.
    auto adjust_factor = 1.0 / int8_param[0].first * int8_param[1].first;
    auto conv_res = prog.insert_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs);
    auto conv_lens = conv_res->get_shape().lens();
    // Broadcast the scalar adjustment factor to the convolution output shape.
    auto fl = prog.add_literal(literal(adjust_factor));
    auto adj_fact = prog.insert_instruction(ins, op::multibroadcast{conv_lens}, fl);
    // NOTE(review): conv_res is never consumed — the original instruction is
    // replaced by the broadcast factor alone, dropping the quant_convolution
    // result. A multiply of conv_res by adj_fact appears to be missing here.
    prog.replace_instruction(ins, adj_fact);
}
else
{
    // NOTE(review): the message lacks a space before the operator name
    // ("operatordot"); left unchanged here since it is a runtime string.
    MIGRAPHX_THROW("INT8_QUANTIZE: does not support operator" + ins->name());
}
prog.replace_instruction(ins, op, converted_inputs); prog.replace_instruction(ins, op, converted_inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment