Unverified Commit b45f7239 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

qdq for quantization and include subgraph (#891)



Add operators, refactor parsers, add rewrite passes, add tests
Add ref implementations
Move broadcasting of scales and zero points to onnx parser
Allow for x and zero_point to have different types in quantizelinear; fix zero_point default type
fp16 and fp8 quantization to include subgraph and parameters
fix unit test to use qdq operators for int8 quantization
Co-authored-by: default avatarturneram <alturner@amd.com>
parent fdaa21ee
......@@ -47,6 +47,8 @@ add_library(migraphx
program.cpp
propagate_constant.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int8.cpp
reduce_dims.cpp
register_op.cpp
register_target.cpp
......
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
......@@ -29,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
// the context argument is added to prevent the op from being eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
......@@ -42,6 +45,8 @@ struct capture
return args.front();
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -26,7 +26,7 @@ struct dequantizelinear
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......
......@@ -17,32 +17,10 @@ struct program;
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
// Thin wrapper around capture_arguments_impl that is only well-formed when
// `t` is passed as an lvalue whose type (ignoring cv/ref qualifiers) is exactly
// `target`. The static_assert message indicates this exists to reject calls
// that would leave a dangling reference to the target.
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
{
    using bare_type = std::remove_cv_t<std::remove_reference_t<T>>;

    constexpr bool is_target = std::is_same<bare_type, target>::value;
    constexpr bool is_lvalue = std::is_lvalue_reference<T>::value;
    static_assert(is_target && is_lvalue, "Dangling reference to target!");

    return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
 * Pass that quantizes a module to fp16.
 *
 * ins_names holds the names of the operators to quantize and defaults to
 * the single entry "all".
 */
struct quantize_fp16_pass
{
    std::vector<std::string> ins_names = {"all"};

    /// Runs the pass on module `m`; the definition lives out of line.
    void apply(module& m) const;

    /// Name the pass reports for itself.
    std::string name() const { return "quantize_fp16"; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
 * Pass that captures the inputs of the operators to be quantized to int8.
 *
 * Data members (their order is part of the aggregate-initialization
 * interface and must not change):
 *  - ins_names:   names of the operators whose inputs are captured;
 *                 defaults to "dot" and "convolution".
 *  - f:           callback taking an index and a vector of arguments;
 *                 empty by default.
 *  - param_index: pointer to an external counter; null by default.
 */
struct capture_arguments_pass
{
    std::vector<std::string> ins_names = {"dot", "convolution"};
    std::function<void(std::size_t, std::vector<argument>)> f{};
    std::size_t* param_index = nullptr;

    /// Runs the pass on module `m`; the definition lives out of line.
    void apply(module& m) const;

    /// Name the pass reports for itself.
    std::string name() const { return "capture_arguments"; }
};
/**
 * Pass that quantizes a module to int8.
 *
 * Data members (their order is part of the aggregate-initialization
 * interface and must not change):
 *  - ins_names:    names of the operators to quantize; defaults to "dot"
 *                  and "convolution".
 *  - quant_params: one pair of floats per quantization parameter.
 */
struct quantize_int8_pass
{
    std::vector<std::string> ins_names = {"dot", "convolution"};
    std::vector<std::pair<float, float>> quant_params;

    /// Runs the pass on module `m`; the definition lives out of line.
    void apply(module& m) const;

    /// Name the pass reports for itself.
    std::string name() const { return "quantize_int8"; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -763,6 +763,22 @@ void program::remove_module(const std::string& name)
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
// if an instruction has an input outside of the current module, the instruction
// must be removed from that input's outputs
auto& mod = impl->modules.at(name);
for(auto ins : iterator_for(mod))
{
auto inputs = ins->inputs();
for(auto in : inputs)
{
if(not mod.has_instruction(in))
{
in->remove_output(ins);
}
}
}
impl->modules.erase(name);
}
......
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <utility>
#include <set>
#include <iomanip>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <fstream>
#include <algorithm>
#include <migraphx/pass_manager.hpp>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
// Returns an instruction that yields `ins` converted to `type`, inserting the
// necessary instructions directly after `ins` in `modl`.
//
//  - A result already recorded in `map_ins` for `ins` is returned as is.
//  - An instruction named "undefined" is returned unchanged and not recorded.
//  - For int8 the value is optionally multiplied by `scale`, optionally
//    offset by `shift`, rounded, clipped to [-128, 127] and then converted.
//  - For every other target type a single convert is inserted.
//
// The new instruction is recorded in `map_ins` before it is returned.
instruction_ref insert_quant_ins(module& modl,
                                 instruction_ref& ins,
                                 shape::type_t type,
                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins,
                                 float scale = 1.0f,
                                 float shift = 0.0f)
{
    // Reuse an earlier conversion of the same instruction.
    const auto cached = map_ins.find(ins);
    if(cached != map_ins.end())
        return cached->second;

    if(ins->name() == "undefined")
        return ins;

    assert(ins->get_shape().type() == shape::float_type or
           ins->get_shape().type() == shape::double_type or
           ins->get_shape().type() == shape::int32_type or
           ins->get_shape().type() == shape::half_type);

    const auto insert_loc = std::next(ins);

    // Gives back `arg` itself when it is already float, otherwise a convert
    // to float inserted at insert_loc.
    const auto as_float = [&](instruction_ref arg) {
        if(arg->get_shape().type() == shape::float_type)
            return arg;
        return modl.insert_instruction(
            insert_loc, make_op("convert", {{"target_type", to_value(shape::float_type)}}), arg);
    };

    // Combines the float version of `arg` with a literal of the same shape
    // filled with `constant`, using the binary operator `op_name`.
    const auto apply_constant =
        [&](const std::string& op_name, float constant, instruction_ref arg) {
            auto float_arg = as_float(arg);
            std::vector<float> fill(arg->get_shape().elements(), constant);
            auto constant_lit = modl.add_literal(literal(float_arg->get_shape(), fill));
            return modl.insert_instruction(insert_loc, make_op(op_name), constant_lit, float_arg);
        };

    instruction_ref quant_ins{};
    if(type != shape::int8_type)
    {
        quant_ins =
            modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
    }
    else
    {
        instruction_ref current = ins;
        if(scale != 1.0f)
            current = apply_constant("mul", scale, current);
        if(shift != 0.0f)
            current = apply_constant("add", shift, current);

        auto rounded_ins        = modl.insert_instruction(insert_loc, make_op("round"), current);
        const auto rounded_lens = rounded_ins->get_shape().lens();

        // Both scalar bounds are created first and broadcast afterwards.
        auto upper = modl.add_literal(127.0f);
        auto lower = modl.add_literal(-128.0f);
        upper      = modl.insert_instruction(
            insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), upper);
        lower = modl.insert_instruction(
            insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), lower);

        auto clipped_ins =
            modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, lower, upper);
        quant_ins = modl.insert_instruction(
            insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
    }

    map_ins[ins] = quant_ins;
    return quant_ins;
}
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
......@@ -119,337 +30,14 @@ instruction_ref insert_quant_ins(module& modl,
// truncate of the input to get the fp16.
// Quantize the main module of `prog` to fp16.
//
// prog:      program whose main module is rewritten in place.
// ins_names: names of the operators to convert; the entry "all" selects
//            every instruction.
//
// For each selected instruction, float and double inputs are replaced by
// fp16 versions and the instruction is rebuilt on those inputs. When that
// changes the instruction's output type, a convert back to the original type
// is inserted after it.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
auto* mm = prog.get_main_module();
// fp16 conversions created so far, keyed by the instruction they convert
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(*mm))
{
// nothing at or after the return instruction is processed
if(ins->name() == "@return")
break;
// "all" indicates that every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
// remembered so the output can be converted back after the rewrite
shape::type_t orig_type = ins->get_shape().type();
// process all inputs: if an input is fp32 or fp64, convert it
// to fp16 by adding a convert operator.
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is itself a convert from an fp16 value, use that
// fp16 value directly instead of converting again
instruction_ref input_fp16{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no input was changed, go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
// output shape the operator will have once it runs on the new inputs
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert: the flag is read before
// the convert below is inserted, because that convert uses `ins`
bool output_empty = ins->outputs().empty();
auto ins_orig_type = mm->insert_instruction(
std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty)
{
mm->replace_instruction(ins, ins_orig_type);
}
}
// finally rebuild the instruction itself on the converted inputs
mm->replace_instruction(ins, op, converted_inputs);
}
}
// Replace a single "dot" or "convolution" instruction by its int8 counterpart
// ("quant_dot" / quant_convolution) and rescale the result.
//
// modl:             module that owns `ins`; new instructions are inserted
//                   before `ins` and `ins` itself is replaced.
// ins:              instruction to replace.
// converted_inputs: already quantized inputs for the new operator. The last
//                   element may be removed by this function (dot with three
//                   inputs, fallback path).
// ins_quant_params: per-input pairs; the `first` members of entries 0 and 1
//                   are used as the scales of the two operands.
//
// Throws (MIGRAPHX_THROW) for any operator other than dot and convolution.
static void ins_quantize_int8(module& modl,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
{
auto orig_type = ins->get_shape().type();
// original (unquantized) inputs; the fallback dot path reads the third one
auto inputs = ins->inputs();
if(ins->name() == "dot")
{
auto dot_op = any_cast<op::dot>(ins->get_operator());
// dividing by the product of both operand scales undoes the scaling
float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
float new_beta = dot_op.beta;
// alpha and beta are rounded to integers for quant_dot only when both
// magnitudes reach the threshold below (50 is a temporary value);
// otherwise the fallback path in the else branch is used
float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
modl.replace_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
}
else
{
// the original output was not int32, so convert the quant_dot result
// back to the original type
auto quant_dot = modl.insert_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
// relative rounding error
else
{
// quant_dot is run as a plain product (alpha 1, beta 0) on the first
// two inputs only; alpha and beta are applied afterwards in float
if(converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = modl.insert_instruction(
ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape();
// literal of the output dimensions filled with new_alpha
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
// result = new_alpha * (A * B) + beta * C, computed in float
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
// the original third input is converted to float before scaling
auto fp32_c = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
}
else
{
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
}
else
{
auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
}
}
else
{
// no third input contributes: result = new_alpha * (A * B)
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
}
else
{
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
}
}
}
}
else if(ins->name() == "convolution")
{
// Current MIOpen convolution does not support alpha and beta,
// so we need a separate multiply to adjust the output
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
// reciprocal of the product of both operand scales
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = modl.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
// multiply directly in the quant_conv output type only when that type
// equals the original type and the factor reaches the threshold
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
}
// convert the quant_conv output to float type, multiply by the factor and
// convert back to the original type
else
{
auto float_conv = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
}
else
{
auto adjusted_conv =
modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
}
}
}
else
{
MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
}
}
// int8 quantization is different from fp16 since int8 can only handle value
// -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now.
// Quantize the dot and convolution instructions of the main module of `prog`
// to int8.
//
// prog:         program whose main module is rewritten in place.
// quant_params: one (scale, shift) pair per distinct input of the quantized
//               instructions, in the order those inputs are first met.
// ins_names:    operators to quantize; must be a subset of
//               {"convolution", "dot"}.
//
// Throws (MIGRAPHX_THROW) when `ins_names` contains an unsupported operator
// or when the number of entries in `quant_params` does not match the number
// of distinct inputs found.
void quantize_int8_impl(program& prog,
                        const std::vector<std::pair<float, float>>& quant_params,
                        const std::vector<std::string>& ins_names)
{
    // Optional dump of the parameters, enabled through the environment variable.
    if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
    {
        for(std::size_t i = 0; i < quant_params.size(); ++i)
        {
            auto param = quant_params.at(i);
            std::cout << "ins_index = " << i << ", scale = " << param.first
                      << ", shift = " << param.second << std::endl;
        }
        std::cout << std::endl;
    }

    // For now, we only support the int8 quantization of gemm and convolution
    std::set<std::string> op_names = {"convolution", "dot"};
    std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
    if(!std::includes(
           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
    {
        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
    }

    auto* mm                      = prog.get_main_module();
    std::size_t quant_param_index = 0;
    // quantized version of each input that was already converted
    std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
    // index into quant_params assigned to each input
    std::unordered_map<instruction_ref, std::size_t> map_ins_index;
    for(auto ins : iterator_for(*mm))
    {
        if(ins->name() == "@return")
            break;

        if(not contains(ins_names, ins->name()))
        {
            continue;
        }

        // for the dot operator, there could be 2 or 3 input arguments
        // if the 3rd argument is available, convert it to an int32.
        std::vector<instruction_ref> converted_inputs;

        // process all inputs, if input is a fp32 or fp64, convert it
        // to a int8 type by adding a convert operator and replace
        // the operator with the corresponding int8 version
        auto inputs = ins->inputs();
        std::vector<std::pair<float, float>> ins_quant_params;
        for(auto input : inputs)
        {
            // calculate the index of each instruction to be quantized
            std::size_t ins_index =
                (map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++;
            map_ins_index[input] = ins_index;

            // Fail before reading past the end of quant_params; previously
            // the mismatch was only detected after the loop had already
            // indexed the vector out of bounds.
            if(ins_index >= quant_params.size())
            {
                MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
            }

            auto param = quant_params[ins_index];
            ins_quant_params.push_back(param);

            // In general, the target_type is int8, but for the dot
            // operation, if it has 3 inputs, then the last one should
            // be converted to int32_type
            shape::type_t quant_type = shape::int8_type;
            if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back()))
            {
                quant_type = shape::int32_type;
            }

            auto s = input->get_shape();
            if((s.type() == shape::float_type or s.type() == shape::double_type or
                s.type() == shape::half_type or s.type() == shape::int32_type) and
               s.type() != quant_type)
            {
                // if the input is a convert operator, uses its input
                // as its current input
                instruction_ref quant_input{};
                if(input->name() == "convert" and
                   input->inputs().front()->get_shape().type() == quant_type)
                {
                    quant_input = input->inputs().front();
                    // the scale in this case is not used, so tune the scale
                    // to 1.0f for this parameter
                    ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
                }
                else
                {
                    quant_input = insert_quant_ins(
                        *mm, input, quant_type, map_quant_ins, param.first, param.second);
                }
                converted_inputs.push_back(quant_input);
            }
            else
            {
                converted_inputs.push_back(input);
            }
        }

        // no change for the input, go to the next instruction
        if(inputs == converted_inputs)
        {
            continue;
        }

        ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
    }

    // too many parameters supplied is a mismatch as well
    if(quant_param_index != quant_params.size())
    {
        MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
    }

    // NOTE(review): this schedules quantize_fp16_pass with the int8 operator
    // list at the end of an int8 routine, which looks out of place. It may
    // belong to quantize_fp16 (this listing appears to interleave two
    // revisions) - confirm before relying on it.
    run_passes(prog,
               {quantize_fp16_pass{ins_names},
                eliminate_common_subexpression{},
                dead_code_elimination{},
                simplify_reshapes{},
                dead_code_elimination{},
                simplify_qdq{},
                dead_code_elimination{}});
}
void quantize_int8(program& prog,
......@@ -457,87 +45,14 @@ void quantize_int8(program& prog,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
cap_prog.eval(m);
}
quantize_int8_impl(prog, *int8_quant_params, ins_names);
}
// For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
auto* mm = prog.get_main_module();
size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*mm))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
instruction_ref new_ins{};
if(ins_map.count(input) > 0)
{
new_ins = ins_map[input];
}
else
{
new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
new_args.push_back(new_ins);
}
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
return num_quant_params;
}
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
......@@ -545,7 +60,6 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift are needed only for the int8 type; shift is not
// considered, so it is set to 0
std::vector<float> vec_val;
......@@ -568,12 +82,56 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
int8_quant_params->at(ins_index) = param_pair;
};
auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
auto capture_prog = prog;
capture_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : capture_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
capture_prog.eval(m);
}
int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
{
auto param = int8_quant_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
return int8_quant_params;
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Rewrite the selected instructions of `m` so that they operate on fp16.
//
// An instruction is selected when its name is listed in `ins_names` or when
// `ins_names` contains "all". "@return" and "convert" instructions and
// instructions without inputs are always left alone.
static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
    const bool quantize_all = contains(ins_names, "all");

    for(auto ins : iterator_for(m))
    {
        const bool selected = quantize_all or contains(ins_names, ins->name());
        if(not selected)
            continue;

        // return and convert instructions are never quantized
        if(ins->name() == "@return" or ins->name() == "convert")
            continue;

        if(ins->inputs().empty())
            continue;

        auto submodules      = ins->module_inputs();
        const auto orig_type = ins->get_shape().type();

        // Put a convert back to the original type behind the instruction
        // first, so its users keep seeing that type. Only done for
        // instructions without module inputs.
        if(submodules.empty())
        {
            auto restored = m.insert_instruction(
                std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
            m.replace_instruction(ins, restored);
        }

        // Swap every float/double input for an fp16 convert placed before `ins`.
        auto args = ins->inputs();
        for(auto& arg : args)
        {
            const auto arg_type = arg->get_shape().type();
            if(arg_type == shape::float_type or arg_type == shape::double_type)
            {
                arg = m.insert_instruction(
                    ins, make_op("convert", {{"target_type", shape::half_type}}), arg);
            }
        }

        // Rebuild the instruction on the new inputs
        m.replace_instruction(ins, ins->get_operator(), args, submodules);
    }
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
shape::float_type, shape::double_type, shape::half_type};
return quantable_types;
}
// Replace every "capture" instruction of `m` whose input has a quantizable
// type by a quantizelinear / dequantizelinear pair on that input.
//
// The capture operator's "ins_index" attribute selects the entry of
// quant_params to use: the scale literal is 1 / first and the zero point is
// `second` cast to int8. Both are broadcast to the input dimensions.
//
// Throws std::out_of_range when "ins_index" does not address an entry of
// quant_params.
void quantize_int8_pass::apply(module& m) const // NOLINT
{
    const auto& quantizable_types = get_quantizable_type();
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "capture")
            continue;

        auto op_val = ins->get_operator().to_value();
        assert(op_val.contains("ins_index"));

        auto param_index = op_val.at("ins_index").to<std::size_t>();
        // Checked access: the index comes from the operator, the parameters
        // from whoever configured the pass, so the two can disagree. The
        // unchecked operator[] used previously would read out of bounds.
        auto param = quant_params.at(param_index);

        auto input = ins->inputs().front();
        auto s     = input->get_shape();
        if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
        {
            auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
            // the scale literal uses the element type of the input
            auto scale       = m.add_literal(literal({s.type()}, {1.0f / param.first}));
            const auto& lens = s.lens();
            scale =
                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
            zero_point = m.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
            auto q_in =
                m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
            auto dq_in =
                m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
            m.replace_instruction(ins, dq_in);
        }
    }
}
// Insert a capture operator in front of every input of each instruction in
// `m` whose name is listed in ins_names, then rebuild the instruction on the
// captured inputs.
//
// Each capture receives the current value of *param_index, which is advanced
// by one per captured input, together with the callback `f`. param_index
// must not be null (asserted).
void capture_arguments_pass::apply(module& m) const // NOLINT
{
    assert(param_index != nullptr);
    for(auto ins : iterator_for(m))
    {
        if(contains(ins_names, ins->name()))
        {
            // copy of the current inputs, taken before anything is inserted
            const auto args = ins->inputs();

            std::vector<instruction_ref> captured;
            for(const auto& arg : args)
            {
                const std::size_t index = (*param_index)++;
                captured.push_back(m.insert_instruction(ins, op::capture{index, f}, arg));
            }

            m.replace_instruction(ins, ins->get_operator(), captured);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -92,7 +92,8 @@ struct match_find_quantizable_ops
dq = m.insert_instruction(
qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
}
dq_scale = m.add_literal(static_cast<float>(scale));
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
......
......@@ -68,8 +68,7 @@ TEST_CASE(int8_quantization)
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(migraphx::op::dot{}, pa, pb, pc);
mm->add_instruction(migraphx::op::dot{}, pa, pb);
return p;
};
......@@ -82,7 +81,6 @@ TEST_CASE(int8_quantization)
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb);
m["c"] = migraphx::generate_argument(sc);
std::vector<float> ref_result;
migraphx::target ref_t = migraphx::ref::target{};
run_prog(p, ref_t, m, ref_result);
......
......@@ -7,46 +7,29 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
// Appends a clip of `input` to the range [min, max] to the main module of `p`
// and returns the clip instruction. Both bounds are added as scalar literals
// and broadcast to the dimensions of `input`.
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
    auto* mm        = p.get_main_module();
    const auto dims = input->get_shape().lens();

    // broadcasts a scalar instruction to the dimensions of `input`
    const auto broadcast = [&](migraphx::instruction_ref scalar) {
        return mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}),
                                   scalar);
    };

    // both literals are created before either broadcast is added
    const auto max_lit = mm->add_literal(max);
    const auto min_lit = mm->add_literal(min);
    const auto upper   = broadcast(max_lit);
    const auto lower   = broadcast(min_lit);

    return mm->add_instruction(migraphx::make_op("clip"), input, lower, upper);
}
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
migraphx::program& p,
float max,
float min,
migraphx::instruction_ref input)
static void optimize_prog_int8(migraphx::program& prog)
{
auto* mm = p.get_main_module();
auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min);
max_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
min_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
migraphx::run_passes(prog,
{migraphx::simplify_qdq{},
migraphx::eliminate_common_subexpression{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(param_add)
......@@ -71,9 +54,9 @@ TEST_CASE(param_add)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(std::next(p1), migraphx::make_op("convert"), p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(std::next(p2), migraphx::make_op("convert"), p2);
auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1);
auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2);
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto res = mm->add_instruction(
migraphx::make_op("convert",
......@@ -130,7 +113,8 @@ TEST_CASE(param_add_sub)
auto p2 = mm->add_parameter("y", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
mm->add_instruction(migraphx::make_op("add"), diff, p1);
auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1);
mm->add_return({r});
return p;
};
......@@ -140,32 +124,21 @@ TEST_CASE(param_add_sub)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(
std::next(p1),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto sum = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hsum);
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum);
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
auto hdiff = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
diff);
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
res);
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res);
mm->add_return({r});
return p;
};
......@@ -174,51 +147,18 @@ TEST_CASE(param_add_sub)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto hsum = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
sum);
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto diff = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hdiff);
mm->add_instruction(migraphx::make_op("add"), diff, p1);
return p;
};
auto create_program_half_all = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto hp1 = mm->insert_instruction(
std::next(p1),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p1);
auto p2 = mm->add_parameter("y", s);
auto hp2 = mm->insert_instruction(
std::next(p2),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto hres = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hres);
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hdiff);
auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1);
mm->add_return({r});
return p;
};
......@@ -236,17 +176,70 @@ TEST_CASE(param_add_sub)
auto p2 = create_program_half_sub();
migraphx::quantize_fp16(p1, {"sub"});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_float();
auto p2 = create_program_half_all();
// Builds a reference program in which every arithmetic op (add, sub, add) is
// individually wrapped: its inputs are converted to half_type and its result
// is converted back to float_type. Because each op is wrapped on its own, the
// graph contains float -> half round trips between consecutive ops and the
// parameters are converted to half more than once.
auto create_program_fp16 = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
// x + y, computed in half and converted back to float
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto sum = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum);
// (x + y) - y: the float sum and y are converted to half again for the sub
auto hsum1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum);
auto p3 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
auto diff = mm->add_instruction(migraphx::make_op("sub"), hsum1, p3);
auto fdiff = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), diff);
// final add with x: again half inputs, float output
auto hdiff1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), fdiff);
auto p4 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff1, p4);
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res);
// the float result is the single program output
mm->add_return({r});
return p;
};
// Builds a reference program where the whole add/sub/add chain runs in
// half_type: each parameter is converted to half exactly once, the three
// arithmetic ops consume half values directly, and a single convert brings
// the final result back to float_type before it is returned.
auto create_program_quant_fp16 = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = mm->add_parameter("x", s);
auto p2 = mm->add_parameter("y", s);
// one half conversion per parameter, shared by all later uses
auto hp1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1);
auto hp2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2);
// ((x + y) - y) + x with no intermediate conversions
auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
auto hres = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
// only the final result is converted back to float
auto r = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hres);
mm->add_return({r});
return p;
};
auto p0 = create_program_float();
migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}});
EXPECT(p0 == create_program_fp16());
auto p1 = create_program_float();
migraphx::quantize_fp16(p1);
migraphx::run_passes(*p1.get_main_module(), {migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
EXPECT(p1 == create_program_quant_fp16());
}
}
......@@ -308,13 +301,125 @@ TEST_CASE(literal_add)
}
}
TEST_CASE(op_capture)
TEST_CASE(fp16_subgraph)
{
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
// Builds the float input program for the subgraph test: a main module holding
// three scalar literals and the parameters cond/x/y, plus an `if` instruction
// whose two branches are the submodules "If_6_if" and "If_6_else". Each branch
// returns three values: two float results and a half_type convert of the
// second one.
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
// literals live in the main module and are referenced from the submodules
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
// then branch: x + l1, y * l2, and the product converted to half
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
auto mfp16 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), mul0);
then_mod->add_return({add0, mul0, mfp16});
// else branch: x * l3, y + l3, and the sum converted to half
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
auto afp16 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), add1);
else_mod->add_return({mul1, add1, afp16});
// the `if` takes cond as its only input; its three outputs are unpacked
// with get_tuple_elem and returned from the main module
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
auto r16 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), ret);
mm->add_return({r0, r1, r16});
return p;
};
// Builds the reference program for the subgraph test with the branch
// arithmetic done in half_type. The main module (literals, parameters, `if`,
// get_tuple_elem, return) has the same layout as the float program; inside
// each submodule the literal and the parameter are converted to half before
// the add/mul, and the result is converted back to float. The third value
// returned by each branch is the half result itself, with no extra convert.
auto create_fp16_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
// then branch, first output: half(x) + broadcast(half(l1)), back to float
auto* then_mod = p.create_module("If_6_if");
auto hl1 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l1);
auto mhl1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1);
auto hx = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1);
auto fad = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad);
// then branch, second output: half(y) * broadcast(half(l2)), back to float
auto hl2 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2);
auto mhl2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2);
auto hy1 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y);
auto mu = then_mod->add_instruction(migraphx::make_op("mul"), hy1, mhl2);
auto fmu = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu);
// third output is the half product `mu` directly
then_mod->add_return({fad, fmu, mu});
// else branch: l3 is converted to half once and broadcast twice
auto* else_mod = p.create_module("If_6_else");
auto hl3 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l3);
auto mhl3 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3);
auto hx2 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto mu1 = else_mod->add_instruction(migraphx::make_op("mul"), hx2, mhl3);
auto fmu1 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu1);
auto mhl4 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl3);
auto hy = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y);
auto ad1 = else_mod->add_instruction(migraphx::make_op("add"), hy, mhl4);
auto fad1 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad1);
// third output is the half sum `ad1` directly
else_mod->add_return({fmu1, fad1, ad1});
// main module: same `if` / get_tuple_elem / return structure as the input program
auto iff = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), iff);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), iff);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), iff);
mm->add_return({r0, r1, r2});
return p;
};
auto p1 = create_program();
migraphx::quantize_fp16(p1);
auto p2 = create_fp16_program();
EXPECT(p1 == p2);
}
TEST_CASE(op_capture)
{
auto create_program_float = [] {
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -343,186 +448,105 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opb = mm->insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
auto opc = mm->insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
auto opa = mm->add_instruction(migraphx::op::capture{0, test_func}, pa);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc);
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc);
auto ops = mm->add_instruction(migraphx::op::capture{3, test_func}, ps);
mm->add_instruction(migraphx::make_op("dot"), opa, ops);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p;
};
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
migraphx::capture_arguments(p, t, {"dot", "convolution"});
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
std::size_t param_index = 0;
migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
EXPECT(p == op_capture_p);
}
}
TEST_CASE(dot_float)
TEST_CASE(op_capture_subgraph)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb, pc);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal(sa, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);
auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal(sb, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fdot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
std::vector<float> v_beta(pc->get_shape().elements(), 1.5f);
auto beta = mm->add_literal(migraphx::literal(pc->get_shape(), v_beta));
auto beta_c = mm->add_instruction(migraphx::make_op("mul"), beta, pc);
mm->add_instruction(migraphx::make_op("add"), alpha_ab, beta_c);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
return p;
};
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), a, b);
then_mod->add_return({out1});
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_double_2args)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto* else_mod = p.create_module("If_6_else");
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w);
else_mod->add_return({out2});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto create_int8_quantized_prog = [] {
auto create_program_op = [&] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
migraphx::shape sc{migraphx::shape::double_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, fpa);
auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
auto fpb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, fpb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fdot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
alpha_ab);
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto ca = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), a);
auto cb = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), b);
auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), ca, cb);
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
auto cx = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), x);
auto cw = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), w);
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), cx, cw);
else_mod->add_return({out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
{
auto p = create_program();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{};
std::size_t param_index = 0;
migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
EXPECT(p == qp);
EXPECT(p == op_capture_p);
}
}
TEST_CASE(dot_large_alpha_beta_float)
TEST_CASE(dot_float)
{
auto create_program = [] {
migraphx::program p;
......@@ -534,8 +558,9 @@ TEST_CASE(dot_large_alpha_beta_float)
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.5f}}), pa, pb, pc);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r});
return p;
};
......@@ -546,153 +571,103 @@ TEST_CASE(dot_large_alpha_beta_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal(sa, vfa));
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta = mm->add_literal(migraphx::literal(sa, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal(sb, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
// quantize parameter c to int32 type
auto qc = mm->insert_instruction(
std::next(pc),
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
pc);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 51}}), qa, qb, qc);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_large_alpha_beta_int32)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), pa, pb, pc);
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(10.0f);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r});
return p;
};
auto create_int8_quantized_prog = [] {
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, conv_a);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, conv_b);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 50}}), qa, qb, pc);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto quant_a = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
const std::vector<std::pair<float, float>> quant_params = {
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
auto p = create_program();
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
optimize_prog_int8(p);
auto op = create_int8_optimized_prog();
EXPECT(p == op);
}
TEST_CASE(dot_int32_one_arg)
TEST_CASE(dot_double_2args)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
auto pa = mm->add_parameter("a", s);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), pa, pa);
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r});
return p;
};
......@@ -700,177 +675,87 @@ TEST_CASE(dot_int32_one_arg)
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
auto pa = mm->add_parameter("a", s);
// add the shift
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
std::vector<float> vsa(s.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, fpa);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
auto q_dot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qa);
auto f_dot = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_dot);
std::vector<float> v_alpha(f_dot->get_shape().elements(), 20.0f);
auto new_alpha = mm->add_literal(migraphx::literal{f_dot->get_shape(), v_alpha});
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, f_dot);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
alpha_ab);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(dot_int32)
{
auto create_program = [](bool add_return = false) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto res = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), pa, pb, pc);
if(add_return)
{
mm->add_return({res});
}
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0);
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
auto create_int8_quantized_prog = [](bool add_return = false) {
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, conv_a);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta =
mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
auto ra = mm->add_instruction(migraphx::make_op("round"), msa);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, conv_b);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
auto scale_a = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto fr = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 20.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fr->get_shape(), v_alpha));
auto alpha_ab = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fr);
auto fc = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pc);
std::vector<float> v_beta(fc->get_shape().elements(), 5.5f);
auto beta = mm->add_literal(migraphx::literal(fc->get_shape(), v_beta));
auto beta_c = mm->add_instruction(migraphx::make_op("mul"), beta, fc);
auto f_res = mm->add_instruction(migraphx::make_op("add"), alpha_ab, beta_c);
auto res = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
f_res);
if(add_return)
{
mm->add_return({res});
}
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
auto p_ret = create_program(true);
migraphx::quantize_int8_impl(p_ret, quant_params, {"dot"});
auto qp_ret = create_int8_quantized_prog(true);
EXPECT(p_ret == qp_ret);
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.2f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
EXPECT(p == create_int8_optimized_prog());
}
TEST_CASE(dot_float_convert)
TEST_CASE(dot_half_1arg)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto fpa = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), fpa, pb);
migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s);
auto r =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r});
return p;
};
......@@ -878,44 +763,67 @@ TEST_CASE(dot_float_convert)
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cb);
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_a);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_a, zp_a);
auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a);
auto zp_b = mm->add_literal(static_cast<int8_t>(0));
auto scale_b = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_b);
zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
auto create_int8_optimized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), pa, qb);
auto fr = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 10.0f);
auto new_alpha = mm->add_literal(migraphx::literal(fr->get_shape(), v_alpha));
mm->add_instruction(migraphx::make_op("mul"), new_alpha, fr);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
EXPECT(p == create_int8_optimized_prog());
}
TEST_CASE(conv_float)
......@@ -927,7 +835,8 @@ TEST_CASE(conv_float)
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -939,119 +848,61 @@ TEST_CASE(conv_float)
migraphx::shape sw{migraphx::shape::float_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(sx, vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, px);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(sw, vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, pw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto f_conv = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_conv);
std::vector<float> v_adj(f_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(f_conv->get_shape(), v_adj));
mm->add_instruction(migraphx::make_op("mul"), adj, f_conv);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(conv_int32)
TEST_CASE(conv_float_throw)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape sw{migraphx::shape::int32_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
auto fpx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
px);
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, fpx);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
auto fpw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, fpw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
std::vector<float> v_adj(q_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(q_conv->get_shape(), v_adj));
mm->add_instruction(migraphx::make_op("mul"), q_conv, adj);
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}});
});
}
TEST_CASE(conv_half)
......@@ -1063,7 +914,8 @@ TEST_CASE(conv_half)
mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -1075,63 +927,43 @@ TEST_CASE(conv_half)
migraphx::shape sw{migraphx::shape::half_type, {4, 3, 3, 3}};
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
// quantize parameter a to int8 type, multiply the scale
auto fpx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
px);
std::vector<float> vfx(sx.elements(), 0.1f);
auto fx = mm->add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = mm->add_instruction(migraphx::make_op("mul"), fx, fpx);
auto rx = mm->add_instruction(migraphx::make_op("round"), mx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cx);
// quantize parameter b to int8 type
auto insert_loc = std::next(pw);
auto fpw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pw);
std::vector<float> vfw(sw.elements(), 0.1f);
auto fw = mm->add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fw, fpw);
auto rw = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = mm->insert_instruction(
insert_loc,
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
cw);
auto q_conv = mm->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto f_conv = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
q_conv);
std::vector<float> v_adj(f_conv->get_shape().elements(), 100.0f);
auto adj = mm->add_literal(migraphx::literal(f_conv->get_shape(), v_adj));
auto f_res = mm->add_instruction(migraphx::make_op("mul"), adj, f_conv);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
f_res);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r});
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
// Returns the std::hash value of `x`; used below to derive per-name seeds
// for generated arguments.
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}
TEST_CASE(target_copy)
{
auto run_prog = [](migraphx::program p,
......@@ -1223,7 +1055,8 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
mm->add_return({r});
return p;
};
......@@ -1232,9 +1065,9 @@ TEST_CASE(int8_quantization_dot)
auto p = create_program();
migraphx::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
m["a"] = migraphx::generate_argument(sa);
m["c"] = migraphx::generate_argument(sc);
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
m["a"] = migraphx::generate_argument(sa, get_hash(std::string("a")));
m["b"] = migraphx::generate_argument(sb, get_hash(std::string("b")));
std::vector<float> quant_result;
migraphx::target ref_t = migraphx::ref::target{};
run_prog(p, ref_t, m, quant_result, true);
......@@ -1272,7 +1105,8 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> v(sx.elements(), 0.5f);
auto input = mm->add_literal(migraphx::literal(sx, v));
auto weights = mm->add_literal(migraphx::literal(sw, v));
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_return({r});
return p;
};
......@@ -1290,4 +1124,156 @@ TEST_CASE(int8_quantization_conv)
}
}
// Checks int8 quantization of instructions that live in submodules.
// A float program whose main module dispatches through an "if" op to a
// "dot" submodule or a "convolution" submodule is run through
// capture_arguments_pass, quantize_int8_pass and optimize_prog_int8, and the
// result must compare equal to a hand-built program made of quantizelinear,
// quant_dot / quant_convolution and dequantizelinear instructions.
TEST_CASE(int8_subgraph)
{
// Input program: float parameters in the main module, one operator per branch.
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
// then-branch: dot(a, b)
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1});
// else-branch: convolution(x, w)
auto* else_mod = p.create_module("If_6_else");
auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w);
else_mod->add_return({out2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
// Expected program. The statement order below is significant because the
// final check is a whole-program equality comparison.
auto create_int8_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}};
migraphx::shape sout{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto a = mm->add_parameter("a", sx);
auto b = mm->add_parameter("b", sy);
// then submod
auto* then_mod = p.create_module("If_6_if");
// One zero-point literal and one scale literal are shared by both dot
// inputs; each is multibroadcast to the shape of the input it quantizes.
auto zp1 = then_mod->add_literal(static_cast<int8_t>(0));
auto s1 = then_mod->add_literal(10.0f);
auto sa = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), s1);
auto zpa = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp1);
auto qa = then_mod->add_instruction(migraphx::make_op("quantizelinear"), a, sa, zpa);
auto sb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot =
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb);
// The quant_dot result is dequantized with a single broadcast scale and
// no zero-point input.
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}};
auto x = mm->add_parameter("x", sd);
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
// Here x and w use different scale literals but share one zero-point literal.
auto sax = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p1 = create_program();
// NOTE(review): the expected scales above (2.0, 1.66667, 10.0, 10.0) look
// like the reciprocals of the first member of each pair, taken in this
// order -- confirm how quantize_int8_pass assigns pairs to captured inputs.
const std::vector<std::pair<float, float>>& quant_params{
{0.5f, 0.0f}, {0.6f, 0.0f}, {0.1f, 0.0f}, {0.1f, 0.0f}};
// The capture pass receives a pointer to this counter, starting at zero.
std::size_t param_index = 0;
migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}});
optimize_prog_int8(p1);
auto p2 = create_int8_program();
EXPECT(p1 == p2);
}
// Checks that running capture_arguments_pass on the "dot" instructions of a
// program does not change what the program evaluates to: the same program is
// compiled and evaluated on the ref target with and without the pass, and
// the two outputs are compared with verify_range.
TEST_CASE(test_op_capture)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
    migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
    std::vector<float> d1(s1.elements());
    std::vector<float> d2(s2.elements());
    // Fill both buffers with 0, 1, 2, ...
    std::iota(d1.begin(), d1.end(), 0.0f);
    std::iota(d2.begin(), d2.end(), 0.0f);

    // All inputs are literals, so the program can be evaluated without
    // supplying any parameters.
    auto p1 = mm->add_literal(s1, d1);
    auto p2 = mm->add_literal(s1, d1);
    auto pb = mm->add_literal(s2, d2);
    auto pc = mm->add_literal(s2, d2);
    auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
    auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
    mm->add_instruction(migraphx::make_op("dot"), pa, ps);

    // No-op callback handed to the pass; this test does not inspect the
    // arguments it would be called with.
    auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};

    // Only the copy goes through the capture pass; `p` stays untouched as
    // the baseline. (The previously declared, never-used `target t` local
    // was removed.)
    migraphx::program capture_p = p;
    std::size_t param_index     = 0;
    migraphx::run_passes(capture_p,
                         {migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}});

    p.compile(migraphx::ref::target{});
    capture_p.compile(migraphx::ref::target{});

    auto cap_res = capture_p.eval({}).back();
    auto res     = p.eval({}).back();

    std::vector<float> vec;
    std::vector<float> cap_vec;
    cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
    res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });

    EXPECT(migraphx::verify_range(vec, cap_vec));
}
// Entry point of the test binary: hands the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -2741,43 +2741,6 @@ TEST_CASE(not_test)
}
}
// Evaluates a literal-only program on the ref target twice -- once as built
// and once after capture_arguments has been applied to its "dot"
// instructions -- and expects both runs to produce the same values.
TEST_CASE(op_capture)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    migraphx::shape shape_3x3{migraphx::shape::float_type, {3, 3}};
    migraphx::shape shape_3x6{migraphx::shape::float_type, {3, 6}};
    std::vector<float> data_3x3(shape_3x3.elements());
    std::vector<float> data_3x6(shape_3x6.elements());
    std::iota(data_3x3.begin(), data_3x3.end(), 0.0f);
    std::iota(data_3x6.begin(), data_3x6.end(), 0.0f);

    // Graph: sum = lhs + rhs; inner = dot(sum, mat_b, mat_c); dot(sum, inner)
    auto lhs   = mm->add_literal(shape_3x3, data_3x3);
    auto rhs   = mm->add_literal(shape_3x3, data_3x3);
    auto mat_b = mm->add_literal(shape_3x6, data_3x6);
    auto mat_c = mm->add_literal(shape_3x6, data_3x6);
    auto sum   = mm->add_instruction(migraphx::make_op("add"), lhs, rhs);
    auto inner = mm->add_instruction(migraphx::make_op("dot"), sum, mat_b, mat_c);
    mm->add_instruction(migraphx::make_op("dot"), sum, inner);

    // Apply the capture step to a copy so `p` remains the baseline.
    migraphx::program captured  = p;
    migraphx::target ref_target = migraphx::ref::target{};
    migraphx::capture_arguments(captured, ref_target, {"dot"});

    p.compile(migraphx::ref::target{});
    captured.compile(migraphx::ref::target{});

    auto captured_out = captured.eval({}).back();
    auto baseline_out = p.eval({}).back();

    std::vector<float> vec;
    std::vector<float> cap_vec;
    captured_out.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
    baseline_out.visit([&](auto output) { vec.assign(output.begin(), output.end()); });

    EXPECT(migraphx::verify_range(vec, cap_vec));
}
TEST_CASE(pad_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment