Commit 7dc6e3ae authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
......@@ -325,12 +325,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
......@@ -338,7 +340,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
......
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <utility>
#include <set>
#include <iomanip>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <fstream>
#include <algorithm>
#include <migraphx/pass_manager.hpp>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f,
float shift = 0.0f)
{
if(map_ins.count(ins) > 0)
{
return map_ins[ins];
}
if(ins->name() == "undefined")
{
return ins;
}
assert(ins->get_shape().type() == shape::float_type or
ins->get_shape().type() == shape::double_type or
ins->get_shape().type() == shape::int32_type or
ins->get_shape().type() == shape::half_type);
instruction_ref quant_ins{};
auto insert_loc = std::next(ins);
if(type == shape::int8_type)
{
auto scaled_ins = ins;
if(scale != 1.0f)
{
auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
if(shift != 0.0f)
{
auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
}
auto rounded_ins = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
auto clipped_ins =
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(
insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
}
else
{
quant_ins =
modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
}
map_ins[ins] = quant_ins;
return quant_ins;
}
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
......@@ -119,337 +30,14 @@ instruction_ref insert_quant_ins(module& modl,
// truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
auto* mm = prog.get_main_module();
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
// all indicates every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// process all inputs, if input is a fp32 or fp64, convert it
// to a fp16 by adding a convert operator.
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is a convert operator, uses its input
// as its current input
instruction_ref input_fp16{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no change for the input, go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type = mm->insert_instruction(
std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty)
{
mm->replace_instruction(ins, ins_orig_type);
}
}
mm->replace_instruction(ins, op, converted_inputs);
}
}
static void ins_quantize_int8(module& modl,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
{
auto orig_type = ins->get_shape().type();
auto inputs = ins->inputs();
if(ins->name() == "dot")
{
auto dot_op = any_cast<op::dot>(ins->get_operator());
float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
float new_beta = dot_op.beta;
// We need additional checking about the quant_alpha value. If
// abs(quant_alpha) > 50 (some tmp value set here), we can convert
// it to an integer as the new_alpha in the quant_dot
float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
modl.replace_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
}
else
{
auto quant_dot = modl.insert_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
// relative rounding error
else
{
if(converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = modl.insert_instruction(
ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
}
else
{
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
}
else
{
auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
}
else
{
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
}
}
}
}
else if(ins->name() == "convolution")
{
// Current MIOpen convolution does not support alpha and beta,
// so we need a separate multiply to adjust the output
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = modl.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// conver back to original type
else
{
auto float_conv = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
}
else
{
auto adjusted_conv =
modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
}
}
}
else
{
MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
}
}
// int8 quantization is different from fp16 since int8 can only handle value
// -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < quant_params.size(); ++i)
{
auto param = quant_params.at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
// For now, we only support the int8 quantization of gemm and convolution
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
if(not contains(ins_names, ins->name()))
{
continue;
}
// for the dot operator, there could be 2 or 3 input arguments
// if the 3rd argument is available, convert it to an int32.
std::vector<instruction_ref> converted_inputs;
// process all inputs, if input is a fp32 or fp64, convert it
// to a int8 type by adding a convert operator and replace
// the operator with the corresponding int8 version
auto inputs = ins->inputs();
std::vector<std::pair<float, float>> ins_quant_params;
for(auto input : inputs)
{
// calculate the index of each instruction to be quantized
std::size_t ins_index =
(map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++;
map_ins_index[input] = ins_index;
auto param = quant_params[map_ins_index[input]];
ins_quant_params.push_back(param);
// In general, the target_type is int8, but for the dot
// operation, if it has 3 inputs, then the last one should
// be converted to int32_type
shape::type_t quant_type = shape::int8_type;
if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back()))
{
quant_type = shape::int32_type;
}
auto s = input->get_shape();
if((s.type() == shape::float_type or s.type() == shape::double_type or
s.type() == shape::half_type or s.type() == shape::int32_type) and
s.type() != quant_type)
{
// if the input is a convert operator, uses its input
// as its current input
instruction_ref quant_input{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == quant_type)
{
quant_input = input->inputs().front();
// the scale in this case is not used, so tune the scale
// to 1.0f for this parameter
ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
}
else
{
quant_input = insert_quant_ins(
*mm, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
else
{
converted_inputs.push_back(input);
}
}
// no change for the input, go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
{
MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
}
run_passes(prog,
{quantize_fp16_pass{ins_names},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
}
void quantize_int8(program& prog,
......@@ -457,87 +45,14 @@ void quantize_int8(program& prog,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
cap_prog.eval(m);
}
quantize_int8_impl(prog, *int8_quant_params, ins_names);
}
// For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
auto* mm = prog.get_main_module();
size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*mm))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
instruction_ref new_ins{};
if(ins_map.count(input) > 0)
{
new_ins = ins_map[input];
}
else
{
new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
new_args.push_back(new_ins);
}
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
return num_quant_params;
}
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
......@@ -545,7 +60,6 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std::vector<float> vec_val;
......@@ -568,12 +82,56 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
int8_quant_params->at(ins_index) = param_pair;
};
auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
auto capture_prog = prog;
capture_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : capture_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
capture_prog.eval(m);
}
int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
{
auto param = int8_quant_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
return int8_quant_params;
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
for(auto ins : iterator_for(m))
{
// instructions are not in the set to be quantized
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
continue;
// skip return and convert instructions
if(contains({"@return", "convert"}, ins->name()))
continue;
if(ins->inputs().empty())
continue;
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
auto input_type = input->get_shape().type();
if(input_type != shape::float_type and input_type != shape::double_type)
return input;
return m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), input);
});
// Replace inputs
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
}
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
shape::float_type, shape::double_type, shape::half_type};
return quantable_types;
}
void quantize_int8_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(ins->name() != "capture")
continue;
auto op_val = ins->get_operator().to_value();
assert(op_val.contains("ins_index"));
auto param_index = op_val.at("ins_index").to<std::size_t>();
auto param = quant_params[param_index];
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
zero_point = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
auto q_in =
m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
auto dq_in =
m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
m.replace_instruction(ins, dq_in);
}
}
}
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
}
instruction_ref hidden_out = prog.end();
......@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(
......@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias);
brb_zr = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
......@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
}
......@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
......
......@@ -239,6 +239,18 @@ struct stream_info
}
}
}
// move dangling parameter to the front so as not be removed
auto ins = std::next(last);
while(ins != p.end())
{
auto next = std::next(ins);
if(ins->name() == "@param")
{
p.move_instruction(ins, p.begin());
}
ins = next;
}
}
void set_stream(const partition& p, std::size_t n)
......@@ -510,6 +522,9 @@ void schedule::apply(module& p) const
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
p.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
std::cout << ":";
std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={";
......@@ -550,11 +565,9 @@ void schedule::apply(module& p) const
{
for(auto i : si.get_recorded_instructions(ins))
{
if(not si.has_stream(i))
continue;
auto istream = si.get_stream(i);
if(stream == istream)
if(not si.has_stream(i) or si.get_stream(i) == stream)
continue;
// Create a new event if it hasn't been recorded
if(not contains(ins2wait, i))
{
......
......@@ -55,7 +55,7 @@ struct find_mul_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
......@@ -120,7 +120,7 @@ struct find_mul_slice_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
......@@ -989,8 +989,8 @@ struct find_split_transpose
}
// insert an transpose instruction
auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
auto tr = p.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
struct match_find_quantizable_ops
{
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
}
void apply(module& m, match::matcher_result r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
}
else if(qop->name() == "dot")
{
auto dot_op = any_cast<op::dot>(qop->get_operator());
if(!(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
return;
if(qop_args.size() == 3)
qop_args.pop_back();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq);
}
};
bool compare_literals(instruction_ref ins1, instruction_ref ins2)
{
if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast")
ins1 = ins1->inputs().front();
auto x = ins1->eval();
if(x.empty())
return false;
auto literal1 = ins1->get_literal();
if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast")
ins2 = ins2->inputs().front();
auto y = ins2->eval();
if(y.empty())
return false;
auto literal2 = ins2->get_literal();
bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals =
std::all_of(
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); });
});
return (x == y) or diff_shapes_equal_vals;
}
void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
{
auto args = ins->inputs();
for(auto&& arg : args)
{
if(arg->name() == "dequantizelinear")
{
auto q = arg->inputs().front();
if((q->name() == "quantizelinear") and
compare_literals(arg->inputs().at(1), q->inputs().at(1)) and
compare_literals(arg->inputs().at(2), q->inputs().at(2)))
{
instruction::replace_argument(ins, arg, q->inputs().front());
}
}
}
}
}
void simplify_qdq::apply(module& m) const
{
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -153,7 +153,8 @@ struct find_transpose
}
else
{
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
p.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
}
}
};
......@@ -278,10 +279,12 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
return p.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
auto t = p.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
......@@ -418,7 +421,7 @@ struct find_resize
auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data);
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
......
......@@ -31,12 +31,29 @@ add_library(migraphx_cpu
set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION})
set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "")
find_package(Threads)
find_package(dnnl REQUIRED)
if(MIGRAPHX_ENABLE_ZENDNN)
find_path(ZENDNN_INC_PATH zendnn.hpp)
find_library(ZENDNN_LIB amdZenDNN)
find_library(BLIS_LIB blis)
else()
find_package(dnnl REQUIRED)
endif()
rocm_clang_tidy_check(migraphx_cpu)
if(MIGRAPHX_ENABLE_ZENDNN)
target_compile_definitions(migraphx_cpu PRIVATE -DMIGRAPHX_ENABLE_ZENDNN)
target_include_directories(migraphx_cpu PRIVATE ${ZENDNN_INC_PATH})
message(STATUS "ZENDNN_LIB: ${ZENDNN_LIB}")
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx Threads::Threads)
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
......
......@@ -37,7 +37,10 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC_0), m.at(DNNL_ARG_SRC_1), m.at(DNNL_ARG_DST)};
return {to_dnnl_algo(algo),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_1)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -11,7 +11,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_MULTIPLE_SRC);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC));
return result;
}
// Custom desc class since its missing in dnnl
......@@ -28,9 +28,9 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
for(auto i = 0; i < m.size() - 1; i++)
{
srcs.push_back(m.at(DNNL_ARG_MULTIPLE_SRC + i));
srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i));
}
return {m.at(DNNL_ARG_DST), std::size_t(op.axis), srcs};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), std::size_t(op.axis), srcs};
}
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
......@@ -15,7 +15,10 @@ namespace cpu {
struct dnnl_convolution
: dnnl_extend_op<dnnl_convolution, dnnl::convolution_forward, op::convolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -45,9 +48,9 @@ struct dnnl_convolution
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_auto,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(padding_l),
......
......@@ -9,7 +9,10 @@ namespace cpu {
struct dnnl_deconvolution
: dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::deconvolution>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
{
......@@ -35,9 +38,9 @@ struct dnnl_deconvolution
dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::deconvolution_direct,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_WEIGHTS),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(dilation),
to_dnnl_dims(op.padding),
......
......@@ -2,6 +2,9 @@
#if defined(__GNUC__) && __GNUC__ <= 5
namespace std {
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#endif
template <>
struct hash<dnnl::algorithm>
{
......
......@@ -39,7 +39,7 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
{
return {dnnl::prop_kind::forward_inference,
to_dnnl_algo(algo),
m.at(DNNL_ARG_SRC_0),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
alpha,
beta};
}
......
......@@ -13,13 +13,20 @@ namespace cpu {
struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
std::vector<int> arg_map(int) const
{
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC),
MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS),
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_WEIGHTS), m.at(DNNL_ARG_DST)};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
};
......
......@@ -7,14 +7,26 @@
#include <migraphx/register_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <unordered_map>
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/assert.hpp>
#ifdef MIGRAPHX_ENABLE_ZENDNN
#include <zendnn.hpp>
#else
#include <dnnl.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
#ifdef MIGRAPHX_ENABLE_ZENDNN
namespace dnnl = zendnn;
#define MIGRAPHX_CONCAT_PREFIX(b) ZENDNN_##b // NOLINT
#else
#define MIGRAPHX_CONCAT_PREFIX(b) DNNL_##b // NOLINT
#endif
#define MIGRAPHX_DNNL_PREFIX(b) MIGRAPHX_CONCAT_PREFIX(b) // NOLINT
struct dnnl_context
{
dnnl::engine engine;
......@@ -102,7 +114,8 @@ struct dnnl_op : auto_register_op<Derived>
static std::size_t get_binary_post_op_arg(std::size_t pos)
{
return DNNL_ARG_ATTR_MULTIPLE_POST_OP(pos) | DNNL_ARG_SRC_1; // NOLINT
return MIGRAPHX_DNNL_PREFIX(ARG_ATTR_MULTIPLE_POST_OP)(pos) | // NOLINT
MIGRAPHX_DNNL_PREFIX(ARG_SRC_1); // NOLINT
}
static std::vector<shape> to_shapes(const std::vector<argument>& args)
......@@ -117,14 +130,18 @@ struct dnnl_op : auto_register_op<Derived>
{
auto desc = prim.get_primitive_desc();
const char* str = nullptr;
#ifdef MIGRAPHX_ENABLE_ZENDNN
zendnn_primitive_desc_query(desc, zendnn_query_impl_info_str, 0, &str);
#else
dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str);
#endif
return str == nullptr ? "" : str;
}
// Map arg index to arg in dnnl
std::vector<int> arg_map(int size) const
{
std::vector<int> result(size);
std::iota(result.begin(), result.end(), DNNL_ARG_SRC_0);
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result;
}
shape base_adjust_shape(const shape& s) const
......@@ -183,8 +200,9 @@ struct dnnl_op : auto_register_op<Derived>
{
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
......@@ -201,7 +219,7 @@ struct dnnl_op : auto_register_op<Derived>
if(contains(op.algo, "binary_add"))
{
auto desc = m.at(arg);
if(desc == m.at(DNNL_ARG_DST))
if(desc == m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)))
po.append_sum(1.0f);
else
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
......@@ -328,7 +346,8 @@ struct dnnl_op : auto_register_op<Derived>
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[DNNL_ARG_DST] = to_dnnl_memory(md.at(DNNL_ARG_DST), args.back());
m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
for(int i = 0; i < args.size() - 1; i++)
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
......
......@@ -31,7 +31,7 @@ struct dnnl_layernorm : dnnl_op<dnnl_layernorm, dnnl::layer_normalization_forwar
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {dnnl::prop_kind::forward_inference,
m.at(DNNL_ARG_SRC),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
1e-12f,
dnnl::normalization_flags::none};
}
......
......@@ -12,7 +12,7 @@ struct dnnl_logsoftmax : dnnl_extend_op<dnnl_logsoftmax, dnnl::logsoftmax_forwar
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
int axis = this->op.axis;
return {dnnl::prop_kind::forward_inference, m.at(DNNL_ARG_SRC_0), axis};
return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis};
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment