Commit 4a39a0f7 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
...
@@ -133,18 +133,11 @@ struct parse_pooling : op_parser<parse_pooling>
                           slice_end.begin(),
                           [](auto i, auto j) { return i + j; });
        }
-        values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
        check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
-        in_lens = l0->get_shape().lens();
-        for(size_t i = 0; i < kdims; i++)
-        {
-            if(values["lengths"][i].to<int64_t>() >
-               in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
-            {
-                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
-            }
-        }
        op.from_value(values);
        auto l1 = info.add_instruction(op, l0);
        if(!slice_start.empty())
        {
...
...
@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
    std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }

+    // y = saturate(round(x / y_scale) + zero_point)
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
-                          std::vector<instruction_ref> args) const
+                          const std::vector<instruction_ref>& args) const
    {
-        auto quant_type = shape::uint8_type;
-        int nargs       = args.size();
-        int max_quant   = 255;
-        int min_quant   = 0;
-        if(nargs == 3)
-            quant_type = args[2]->get_shape().type();
-        if(quant_type == shape::int8_type)
-        {
-            max_quant = 127;
-            min_quant = -128;
-        }
-        auto max_arg = info.add_literal(max_quant);
-        auto min_arg = info.add_literal(min_quant);
        int axis = 1;
        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        auto input_lens = args[0]->get_shape().lens();
-        int n_dim       = static_cast<int>(input_lens.size());
+        auto n_dim      = input_lens.size();

-        auto scale = args[1];
-        if(not(scale->get_shape().elements() == 1))
+        instruction_ref y_scale;
+        if(args[1]->get_shape().elements() != 1)
        {
-            axis  = tune_axis(n_dim, axis, opd.op_name);
-            scale = info.add_instruction(
-                make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
+            auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+            y_scale         = info.add_instruction(
+                make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
+        }
+        else
+        {
+            y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
+                                           args[1]);
        }

-        auto div            = info.add_broadcastable_binary_op("div", args[0], scale);
-        auto div_round      = info.add_instruction(make_op("round"), div);
-        auto add_zero_point = div_round;
-        if(nargs == 3)
+        if(args.size() == 3)
        {
-            auto zero_point = args[2];
-            if(not(zero_point->get_shape().elements() == 1))
+            auto y_zero_point = args[2];
+            if(y_zero_point->get_shape().elements() != 1)
            {
-                axis       = tune_axis(n_dim, axis, opd.op_name);
-                zero_point = info.add_instruction(
-                    make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
+                auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+                y_zero_point    = info.add_instruction(
+                    make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
+                    y_zero_point);
+            }
+            else
+            {
+                y_zero_point = info.add_instruction(
+                    make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
            }
-            zero_point = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
-            add_zero_point = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), add_zero_point);
-            add_zero_point = info.add_broadcastable_binary_op("add", add_zero_point, zero_point);
+            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
        }

-        auto s           = add_zero_point->get_shape();
-        const auto& lens = s.lens();
-        std::vector<int64_t> out_lens(lens.begin(), lens.end());
-        if(min_arg->get_shape() != s)
-        {
-            min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
-                                           min_arg);
-        }
-        if(max_arg->get_shape() != s)
-        {
-            max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
-                                           max_arg);
-        }
-        auto saturated = info.add_instruction(make_op("clip"), add_zero_point, min_arg, max_arg);
-        return info.add_instruction(make_op("convert", {{"target_type", quant_type}}), saturated);
+        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
    }
};
...
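The added comment above states the whole ONNX QuantizeLinear contract in one line. A minimal scalar sketch of that formula with assumed values (scale 0.1, int8 zero point 3), just to make the saturating arithmetic concrete:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // y = saturate(round(x / y_scale) + zero_point), int8 flavor
    int8_t quantize_linear(float x, float y_scale, int zero_point)
    {
        int q = static_cast<int>(std::nearbyint(x / y_scale)) + zero_point;
        return static_cast<int8_t>(std::min(127, std::max(-128, q)));
    }
    // quantize_linear(12.34f, 0.1f, 3) -> 123 + 3 = 126
    // quantize_linear(99.0f, 0.1f, 3)  -> 990 + 3 saturates to 127
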
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <algorithm>
#include <chrono>
#include <random>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
{
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int dtype = 1;
bool use_dtype = false;
if(contains(info.attributes, "dtype"))
{
dtype = info.attributes.at("dtype").i();
use_dtype = true;
}
shape::type_t out_type = get_type(dtype);
if(not contains(valid_types, out_type))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
float mean = 0.0;
if(contains(info.attributes, "mean"))
mean = info.attributes.at("mean").f();
float scale = 1.0;
if(contains(info.attributes, "scale"))
scale = info.attributes.at("scale").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape;
if(contains(info.attributes, "shape"))
{
// RandomNormal:
// output type and shape must come from attributes
std::vector<int> out_lens;
literal ls = parser.parse_value(info.attributes.at("shape"));
ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); });
out_shape = shape{out_type, out_lens};
}
else if(args.size() == 1)
{
// RandomNormalLike:
// output type and shape are the same as the input's by default
// dtype is used instead when attribute is set
if(not contains(valid_types, args[0]->get_shape().type()))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " +
std::to_string(args[0]->get_shape().type()) +
". Valid types are float, half, and double.");
out_shape =
use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape();
}
else
{
MIGRAPHX_THROW(opd.op_name +
": cannot deduce shape without shape attribute or argument.");
}
std::mt19937 gen(seed);
std::normal_distribution<> d(mean, scale);
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
return info.add_literal(literal{out_shape, rand_vals});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
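Note that both random parsers evaluate the distribution at parse time and bake the result into a literal, so a compiled program returns the same "random" tensor on every execution; the clock-based default seed only varies between parses. A condensed, standalone sketch of that generation step (mean, scale, seed, and element count are assumed values):

    #include <algorithm>
    #include <cstddef>
    #include <random>
    #include <vector>

    std::vector<double> sample_normal(float mean, float scale, float seed, std::size_t n)
    {
        // ONNX carries the seed as a float attribute; narrow it explicitly here
        std::mt19937 gen(static_cast<std::mt19937::result_type>(seed));
        std::normal_distribution<> d(mean, scale);
        std::vector<double> vals(n);
        std::generate(vals.begin(), vals.end(), [&] { return d(gen); });
        return vals;
    }
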
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <algorithm>
#include <chrono>
#include <random>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
{
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int dtype = 1;
bool use_dtype = false;
if(contains(info.attributes, "dtype"))
{
dtype = info.attributes.at("dtype").i();
use_dtype = true;
}
shape::type_t out_type = get_type(dtype);
if(not contains(valid_types, out_type))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
float high = 1.0;
if(contains(info.attributes, "high"))
high = info.attributes.at("high").f();
float low = 0.0;
if(contains(info.attributes, "low"))
low = info.attributes.at("low").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape;
if(contains(info.attributes, "shape"))
{
// RandomUniform:
// output type and shape must come from attributes
std::vector<int> out_lens;
literal ls = parser.parse_value(info.attributes.at("shape"));
ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); });
out_shape = shape{out_type, out_lens};
}
else if(args.size() == 1)
{
// RandomUniformLike:
// output type and shape are the same as the input by default
// dtype is used instead when attribute is set
if(not contains(valid_types, args[0]->get_shape().type()))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " +
std::to_string(args[0]->get_shape().type()) +
". Valid types are float, half, and double.");
out_shape =
use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape();
}
else
{
MIGRAPHX_THROW(opd.op_name +
": cannot deduce shape without shape attribute or argument.");
}
std::mt19937 gen(seed);
        std::uniform_real_distribution<> d(low, high); // bounds must be ordered (low, high)
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
return info.add_literal(literal{out_shape, rand_vals});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
@@ -55,7 +55,7 @@ const auto& get_original_idx_op(const std::string& mode)
     }},
    {"align_corners",
     [=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
-         return 1.0 * idx * (l_in - 1.0) / (l_out - 1.0);
+         return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0));
     }},
    {"asymmetric",
     [=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
...
@@ -71,6 +71,96 @@ const auto& get_original_idx_op(const std::string& mode)
    return idx_ops.at(mode);
}
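// Worked instance of the align_corners mapping above (sizes assumed): with
// l_in = 4 and l_out = 8, output index 7 maps back to 7 * 3.0 / 7.0 = 3.0,
// the last input coordinate. The new (l_out == 1) guard returns 0.0 instead
// of dividing by zero when the output dimension has a single element.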
static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims,
const shape& in_s)
{
if(i_dim == vvv_ind.size())
{
std::vector<int> vec_ind;
vec_ind.resize(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
return static_cast<int>(in_s.index(idx));
});
return vec_ind;
}
const auto& vv_ind = vvv_ind[i_dim];
const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{
std::transform(vv_lo.begin(),
vv_lo.end(),
vec_dims.begin() + start,
std::back_inserter(vec_dims1),
[](auto i, auto dim) {
dim.push_back(i);
return dim;
});
}
const auto& vv_hi = vv_ind.at(1);
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{
std::transform(vv_hi.begin(),
vv_hi.end(),
vec_dims.begin() + start,
std::back_inserter(vec_dims1),
[](auto i, auto dim) {
dim.push_back(i);
return dim;
});
}
return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s);
}
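// A concrete trace of what calc_neighbor_points builds, with assumed values:
// for n_dim = 2 and a single output point whose floor/ceil indices are {1, 2}
// in dimension 0 and {3, 4} in dimension 1, the recursion emits the 2^2
// corner combinations
//     (1, 3), (2, 3), (1, 4), (2, 4)
// appending the low corners of each dimension before the high corners, so the
// lo/hi halves of the final list stay contiguous; each corner is then
// flattened through in_s.index() so that a single 1-D gather fetches every
// neighbor at once.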
static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
{
std::string coord_trans_mode = "half_pixel";
if(contains(attr, "coordinate_transformation_mode"))
{
coord_trans_mode = attr.at("coordinate_transformation_mode").s();
// does not support transformation mode "tf_crop_and_resize"
if(coord_trans_mode == "tf_crop_and_resize")
{
MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
}
}
return coord_trans_mode;
}
static std::string get_mode(const onnx_parser::attribute_map& attr)
{
std::string mode = "nearest";
if(contains(attr, "mode"))
{
mode = attr.at("mode").s();
if(mode != "nearest" and mode != "linear")
{
MIGRAPHX_THROW("PARSE_RESIZE: only nearest and linear modes are supported!");
}
}
return mode;
}
static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
{
std::string nearest_mode = "round_prefer_floor";
if(contains(attr, "nearest_mode"))
{
nearest_mode = attr.at("nearest_mode").s();
}
return nearest_mode;
}
struct parse_resize : op_parser<parse_resize>
{
    std::vector<op_desc> operators() const { return {{"Resize"}}; }
...
@@ -80,42 +170,20 @@ struct parse_resize : op_parser<parse_resize>
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
-        std::string coord_trans_mode = "half_pixel";
-        if(contains(info.attributes, "coordinate_transformation_mode"))
-        {
-            coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s();
-            // does not support transformation mode "tf_crop_and_resize"
-            if(coord_trans_mode == "tf_crop_and_resize")
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
-            }
-        }
+        // coord transform mode
+        std::string coord_trans_mode = get_coord_trans_mode(info.attributes);

-        // mode: only nearest mode is supported for now
-        if(contains(info.attributes, "mode"))
-        {
-            auto mode = info.attributes.at("mode").s();
-            if(mode != "nearest")
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: only nearest mode is supported!");
-            }
-        }
+        // mode: only nearest and linear modes are supported for now
+        std::string mode = get_mode(info.attributes);

        // nearest mode
-        std::string nearest_mode = "round_prefer_floor";
-        if(contains(info.attributes, "nearest_mode"))
-        {
-            nearest_mode = info.attributes.at("nearest_mode").s();
-        }
+        std::string nearest_mode = get_nearest_mode(info.attributes);

        // check exclude_outside, only support 0
-        if(contains(info.attributes, "exclude_outside"))
+        if(contains(info.attributes, "exclude_outside") and
+           info.attributes.at("exclude_outside").i() == 1)
        {
-            int exclude_outside = info.attributes.at("exclude_outside").i();
-            if(exclude_outside == 1)
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
-            }
+            MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
        }

        // input data shape info
...
@@ -128,74 +196,164 @@ struct parse_resize : op_parser<parse_resize>
        // scale
        std::vector<double> vec_scale;

-        // output size is specified in input, so use it as output size
-        if(args.size() == 4 and args.back()->name() != "undefined")
-        {
-            auto arg_out_s = args[3]->eval();
-            check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
-            arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
-            if(out_lens.size() != in_lens.size())
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
-            }
-
-            // compute the scale
-            vec_scale.resize(in_lens.size());
-            std::transform(in_lens.begin(),
-                           in_lens.end(),
-                           out_lens.begin(),
-                           vec_scale.begin(),
-                           [](auto iss, auto oss) { return 1.0 * oss / iss; });
-        }
-        // need to compute the output lens from input
-        else
-        {
-            auto arg_scale = args[2]->eval();
-            check_arg_empty(arg_scale, "PARSE_RESIZE: dynamic input scale is not supported!");
-
-            arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
-            if(in_lens.size() != vec_scale.size())
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
-            }
-
-            std::transform(
-                in_lens.begin(),
-                in_lens.end(),
-                vec_scale.begin(),
-                out_lens.begin(),
-                [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
-        }
+        for(const auto& arg : args)
+        {
+            if(arg->name() == "undefined" or arg == args.front())
+            {
+                continue;
+            }
+
+            // skipped empty input
+            auto lens = arg->get_shape().lens();
+            if(lens.empty())
+            {
+                continue;
+            }
+
+            auto type = arg->get_shape().type();
+            // output size
+            if(type == shape::int64_type)
+            {
+                auto arg_out_s = arg->eval();
+                check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
+                arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
+                if(out_lens.size() != in_lens.size())
+                {
+                    MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
+                }
+
+                // compute the scale
+                vec_scale.resize(in_lens.size());
+                std::transform(in_lens.begin(),
+                               in_lens.end(),
+                               out_lens.begin(),
+                               vec_scale.begin(),
+                               [](auto iss, auto oss) { return 1.0 * oss / iss; });
+            }
+            else
+            {
+                // scale input
+                if(lens[0] == in_lens.size())
+                {
+                    auto arg_scale = arg->eval();
+                    check_arg_empty(arg_scale,
+                                    "PARSE_RESIZE: dynamic input scale is not supported!");
+                    arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
+                    if(in_lens.size() != vec_scale.size())
+                    {
+                        MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
+                    }
+
+                    std::transform(in_lens.begin(),
+                                   in_lens.end(),
+                                   vec_scale.begin(),
+                                   out_lens.begin(),
+                                   [&](auto idx, auto scale) {
+                                       return static_cast<std::size_t>(idx * scale);
+                                   });
+                }
+            }
+        }

        shape out_s{in_s.type(), out_lens};
-        std::vector<int> ind(out_s.elements());
-
-        // map out_idx to in_idx
-        auto nearest_op = get_nearest_op(nearest_mode);
-        auto idx_op     = get_original_idx_op(coord_trans_mode);
-
-        shape_for_each(out_s, [&](auto idx) {
-            auto in_idx = idx;
-            for(auto ii = 0; ii < in_lens.size(); ++ii)
-            {
-                auto idx_val = idx_op(in_lens[ii], out_lens[ii], in_idx[ii], vec_scale[ii]);
-                in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
-            }
-
-            ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
-        });
-
-        // reshape input to one-dimension
-        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
-        shape ind_s{shape::int32_type, out_lens};
-        auto arg_cont = info.make_contiguous(args[0]);
-        auto rsp      = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), arg_cont);
-        auto ins_ind  = info.add_literal(literal(ind_s, ind));
-        return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+        std::size_t out_elements = out_s.elements();
+        auto idx_op              = get_original_idx_op(coord_trans_mode);
+
+        // reshape input to one-dimension
+        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
+        args[0]                       = info.make_contiguous(args[0]);
+        auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
+
+        if(mode == "nearest")
+        {
+            std::vector<int> ind(out_elements);
+            // map out_idx to in_idx
+            auto nearest_op = get_nearest_op(nearest_mode);
+            shape_for_each(out_s, [&](auto idx) {
+                auto in_idx = idx;
+                for(auto ii = 0; ii < in_lens.size(); ++ii)
+                {
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
+                }
+                ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
+            });
+
+            shape ind_s{shape::int32_type, out_lens};
+            auto ins_ind = info.add_literal(literal(ind_s, ind));
+            return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+        }
+        // linear mode
+        else
+        {
+            auto nearest_floor = get_nearest_op("floor");
+            auto nearest_ceil  = get_nearest_op("ceil");
+
+            // get the number of dimensions
+            std::size_t n_dim = out_lens.size();
+            std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
+            std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
+            std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
+
+            shape_for_each(out_s, [&](auto idx) {
+                auto in_idx  = idx;
+                auto out_idx = out_s.index(idx);
+                for(auto ii = 0; ii < in_lens.size(); ++ii)
+                {
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
+                    vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
+                    delta[ii][out_idx]      = idx_val - vvv_ind[ii][0][out_idx];
+                }
+            });
+
+            std::vector<std::vector<std::size_t>> vec_dims(out_elements);
+            auto ind      = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s);
+            auto ind_lens = out_lens;
+            ind_lens[0] *= (std::size_t{1} << n_dim);
+            shape ind_s{shape::int32_type, ind_lens};
+            auto ins_ind = info.add_literal(literal(ind_s, ind));
+            auto data    = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+
+            auto dim_lens = out_lens;
+            dim_lens[0] *= (std::size_t{1} << (n_dim - 1));
+            for(std::size_t i = 0; i < n_dim; ++i)
+            {
+                shape dim_s{shape::float_type, dim_lens};
+                const auto& dim_delta = delta[n_dim - i - 1];
+                std::vector<float> delta_data;
+                for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)
+                {
+                    delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
+                }
+                auto ins_delta = info.add_literal(dim_s, delta_data);
+
+                // slice the data
+                int64_t slc_stride = static_cast<int64_t>(dim_lens[0]);
+                auto low           = info.add_instruction(
+                    make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {slc_stride}}}),
+                    data);
+                auto hi = info.add_instruction(
+                    make_op("slice",
+                            {{"axes", {0}}, {"starts", {slc_stride}}, {"ends", {2 * slc_stride}}}),
+                    data);
+                auto diff = info.add_instruction(make_op("sub"), hi, low);
+                auto ddf  = info.add_instruction(make_op("mul"), diff, ins_delta);
+                data      = info.add_instruction(make_op("add"), ddf, low);
+                dim_lens[0] /= 2;
+            }
+
+            return data;
+        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
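Each pass of the interpolation loop in the linear branch collapses one dimension by blending the gathered floor half of the data against the ceil half, which is why dim_lens[0] is cut in half per iteration. The sub/mul/add triple is a plain lerp; a one-line sketch with assumed operands:

    // out = low + delta * (hi - low); e.g. low = 2, hi = 6, delta = 0.25 -> 3
    float lerp(float low, float hi, float delta) { return low + delta * (hi - low); }
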
...
@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
        if(lens != std::vector<std::size_t>{1})
        {
            l_alpha =
-                info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
+                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
            l_gamma =
-                info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
+                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
        }
        auto sign_x = info.add_instruction(make_op("sign"), args[0]);
...
...
@@ -20,18 +20,15 @@ struct parse_slice : op_parser<parse_slice>
    {
        op::slice op;
+        std::vector<int64_t> steps;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
-            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
-            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
-            {
-                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
-            }
        }

        if(args.size() >= 4)
...
@@ -77,7 +74,38 @@ struct parse_slice : op_parser<parse_slice>
            op.axes = axes;
        }

-        return info.add_instruction(op, args[0]);
+        std::vector<int64_t> raxes;
assert(steps.empty() or steps.size() == op.axes.size());
assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size());
for(auto i : range(steps.size()))
{
if(steps[i] >= 0)
continue;
op.starts[i] += 1;
if(op.starts[i] == 0)
op.starts[i] = INT_MAX;
op.ends[i] += 1;
raxes.push_back(op.axes[i]);
std::swap(op.starts[i], op.ends[i]);
}
auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty())
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
return std::abs(s);
});
return ins = info.add_instruction(
make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
}
else
return ins;
    }
};
...
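The block above turns a negative ONNX step into slice + reverse + positive stride. A hypothetical 1-D walk-through of starts=7, ends=0, steps=-2 over eight elements, mirroring the shift-and-swap of the bounds:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> x = {0, 1, 2, 3, 4, 5, 6, 7};
        // shifted/swapped bounds: starts 7+1=8, ends 0+1=1, swap -> slice [1, 8)
        std::vector<int> s(x.begin() + 1, x.begin() + 8);
        std::reverse(s.begin(), s.end()); // {7, 6, 5, 4, 3, 2, 1}
        std::vector<int> out;
        for(std::size_t i = 0; i < s.size(); i += 2) // stride |step| = 2
            out.push_back(s[i]);                     // {7, 5, 3, 1}, same as x[7:0:-2]
        for(int v : out)
            std::printf("%d ", v);
        return 0;
    }
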
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
auto x_shape = args[0]->get_shape();
auto lit_zero = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
auto lit_alpha =
info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {alpha}});
auto mb_zero = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_zero);
auto mb_alpha = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_alpha);
auto condition = info.add_instruction(migraphx::make_op("greater"), args[0], mb_alpha);
return info.add_instruction(migraphx::make_op("where"), condition, args[0], mb_zero);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
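ThresholdedRelu has no dedicated operator here, so the parser composes it from greater and where: y = x if x > alpha else 0, with alpha defaulting to 1.0. A scalar sketch of the same formula:

    // thresholded_relu(1.5f, 1.0f) -> 1.5f; thresholded_relu(0.5f, 1.0f) -> 0.0f
    float thresholded_relu(float x, float alpha) { return x > alpha ? x : 0.0f; }
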
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_topk : op_parser<parse_topk>
{
std::vector<op_desc> operators() const { return {{"TopK"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int64_t k = 0;
if(args.size() == 2)
{
auto arg_k = args.at(1)->eval();
check_arg_empty(arg_k, "PARSE_TopK: k input must be constant");
k = arg_k.at<int>();
}
else if(contains(info.attributes, "k"))
{
k = info.attributes.at("k").i();
}
bool largest = true;
if(contains(info.attributes, "largest"))
{
largest = static_cast<bool>(info.attributes.at("largest").i());
}
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
}
auto topk_ret = info.add_instruction(
make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));
auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret);
auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret);
return {ret_val, ret_ind};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
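TopK is the only parser in this batch that yields two outputs; the tuple produced by the topk instruction is split with get_tuple_elem so values and indices can be consumed separately. A hedged usage sketch against the same builder API (the module pointer, shape, and names are made up):

    // migraphx::module* mm = ...;
    // auto x    = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 8}});
    // auto topk = mm->add_instruction(
    //     migraphx::make_op("topk", {{"k", 3}, {"axis", -1}, {"largest", 1}}), x);
    // auto vals = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), topk);
    // auto inds = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), topk);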
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
@@ -21,7 +22,21 @@ struct parse_transpose : op_parser<parse_transpose>
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
-        return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
// if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size();
if(perm.empty())
{
perm.resize(n_dim);
std::iota(perm.rbegin(), perm.rend(), 0);
}
if(perm.size() != n_dim)
{
MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have diffferent number of dims!");
}
return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
    }
};
...
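When perm is absent, ONNX defines Transpose as a full axis reversal, which is exactly what the reverse iota above produces. A tiny standalone sketch of that default:

    #include <cstddef>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> default_perm(std::size_t n_dim)
    {
        std::vector<int64_t> perm(n_dim);
        std::iota(perm.rbegin(), perm.rend(), 0); // n_dim = 4 -> {3, 2, 1, 0}
        return perm;
    }
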
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
+#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
...
@@ -16,45 +17,28 @@ struct parse_where : op_parser<parse_where>
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
-        auto cond =
-            info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
-        auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
-        lens      = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
-        if(cond->get_shape().lens() != lens)
+        auto lens =
+            compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
+        lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
+        if(args[0]->get_shape().lens() != lens)
        {
-            cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
+            args[0] =
+                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
        }

        if(args[1]->get_shape().lens() != lens)
        {
            args[1] =
-                info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
+                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
        }

        if(args[2]->get_shape().lens() != lens)
        {
            args[2] =
-                info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
+                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
        }

-        // compute index
+        return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
    }
};
...
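The deleted branch documents how Where used to be emulated before a native where operator existed: concatenate {y, x} into one flat buffer and gather with index = iota + cond * elem_num, so a true condition selects from the x half. A worked instance with assumed sizes: for elem_num = 4, cond = {1, 0, 1, 1} and iota = {0, 1, 2, 3}, the computed indices are {4, 1, 6, 7}, which picks x[0], y[1], x[2], x[3].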
...
@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
    std::size_t size = s.bytes();
    if(size == 0)
        return false;
-    std::size_t element_size = size / s.elements();
+    std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
    live_range& segment = interval->segment;
    int vn              = segment.vn;
    std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
...
@@ -249,8 +249,8 @@ void memory_coloring_impl::verify()
            if(segment.begin == invalid_offset)
            {
-                if(!interval.is_live_on_entry)
-                    MIGRAPHX_THROW("interval is not live on entry");
+                // if(!interval.is_live_on_entry)
+                //     MIGRAPHX_THROW("interval is not live on entry");
                continue;
            }
...
...
@@ -32,14 +32,6 @@ void validate_pass(module& mod, const pass& p, tracer trace)
    trace();
#endif
}
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace)
{
    trace("Pass: ", p.name());
...
@@ -47,11 +39,52 @@ void run_pass(program& prog, const pass& p, tracer trace)
    trace(prog);
}
struct module_pm : module_pass_manager
{
module* mod;
program* prog;
tracer* t;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual void run_pass(const pass& p) override
{
assert(mod);
trace("Module: ", mod->name(), ", Pass: ", p.name());
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
}
};
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
    for(const auto& p : passes)
    {
-        run_pass(mod, p, trace);
+        module_pm{&mod, nullptr, &trace}.run_pass(p);
    }
}
...
@@ -62,7 +95,9 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
        auto mods = prog.get_modules();
        for(const auto& mod : reverse(mods))
        {
-            run_pass(*mod, p, trace);
+            if(mod->bypass())
+                continue;
+            module_pm{mod, &prog, &trace}.run_pass(p);
        }
        run_pass(prog, p, trace);
    }
...
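The module_pm adapter above is what lets a pass run per module while still reaching the owning program to create helper modules. A hedged sketch of a pass written against that interface (the pass name and body are invented):

    // struct strip_noops
    // {
    //     std::string name() const { return "strip_noops"; }
    //     void apply(migraphx::module_pass_manager& mpm) const
    //     {
    //         migraphx::module& m = mpm.get_module();
    //         // ... rewrite m; helper modules come from mpm.create_module("subgraph");
    //     }
    // };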
-#include <migraphx/gpu/preallocate_param.hpp>
-#include <migraphx/gpu/hip.hpp>
-#include <migraphx/program.hpp>
+#include <migraphx/preallocate_param.hpp>
#include <migraphx/instruction.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
-#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {

-void preallocate_param::apply(module& p) const
+void preallocate_param::apply(module& m) const
{
-    for(auto ins : iterator_for(p))
+    auto last = std::prev(m.end());
+    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "@param")
            continue;
        if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
            continue;
-        std::string id = p.name() + ":" + param;
-        auto r         = p.insert_instruction(ins, hip_allocate_memory{ins->get_shape(), id});
-        p.replace_instruction(ins, r);
+        std::string id = m.name() + ":" + param;
+        auto r         = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
+        m.replace_instruction(ins, r);
+        m.move_instruction(ins, m.end());
    }
+    m.remove_instructions(std::next(last), m.end());
}

-} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
+#include <migraphx/iterator.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
...
@@ -25,6 +26,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using milliseconds = std::chrono::duration<double, std::milli>;
struct program_impl
{
    // A map is used to keep references to modules of the program
...
@@ -181,14 +184,16 @@ std::vector<argument> generic_eval(const module* mod,
                                   context& ctx,
                                   std::unordered_map<std::string, argument> params,
                                   std::unordered_map<instruction_ref, argument> results,
-                                   F trace)
+                                   F make_trace)
{
    assert(mod->validate() == mod->end());
    results.reserve(mod->size() * 2);
    std::vector<argument> values;
    values.reserve(16);
+    auto trace = make_trace(mod);
    for(auto ins : iterator_for(*mod))
    {
+        assert(results.find(ins) == results.end());
        const auto& name = ins->name();
        if(name == "@literal")
        {
...
@@ -237,25 +242,17 @@ std::vector<argument> generic_eval(const module* mod,
            const auto& mod_args = ins->module_inputs();
            auto module_eval     = [&](module_ref smod,
                                   const std::unordered_map<std::string, argument>& inputs) {
-                return generic_eval(smod, ctx, inputs, results, trace);
+                auto ssctx = ctx;
+                return generic_eval(smod, ssctx, inputs, results, make_trace);
            };

-            if(not mod_args.empty())
-            {
-                results.emplace(ins, trace(ins, [&] {
-                                    return ins->normalized_operator().compute(
-                                        values, mod_args, module_eval);
-                                }));
-            }
-            else
-            {
-                results.emplace(ins, trace(ins, [&] {
-                                    return ins->normalized_operator().compute(
-                                        ctx, ins->get_shape(), values);
-                                }));
-            }
+            results.emplace(ins, trace(ins, [&] {
+                                return ins->normalized_operator().compute(
+                                    ctx, ins->get_shape(), values, mod_args, module_eval);
+                            }));
        }
        assert(results.find(ins) != results.end());
+        assert(results.at(ins).get_shape() == ins->get_shape());
    }
    return {results.at(std::prev(mod->end()))};
}
...
...
@@ -264,46 +261,67 @@ template <class F>
std::vector<argument> generic_eval(const program& p,
                                   context& ctx,
                                   std::unordered_map<std::string, argument> params,
-                                   F trace)
+                                   F make_trace)
{
    const module* mm = p.get_main_module();
-    return generic_eval(mm, ctx, params, {}, trace);
+    return generic_eval(mm, ctx, params, {}, make_trace);
}

std::vector<argument> program::eval(parameter_map params) const
{
    auto& ctx = this->impl->ctx;
#ifndef NDEBUG
-    auto sctx          = ctx;
-    auto check_context = [&](auto f) {
-        assert(is_shared(ctx, sctx));
-        auto x = f();
-        sctx   = ctx;
-        return x;
-    };
+    auto with_check_context = [&](auto f) {
+        return [=, &ctx](auto&&) {
+            auto sctx          = std::make_shared<context>(ctx);
+            auto check_context = [=, &ctx](auto g) {
+                assert(is_shared(ctx, *sctx));
+                auto x = g();
+                *sctx  = ctx;
+                return x;
+            };
+            return [=](auto&&... xs) { return f(xs..., check_context); };
+        };
+    };
#else
-    auto check_context = [](auto f) { return f(); };
+    auto with_check_context = [](auto f) {
+        return [=](auto&&) {
+            return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
+        };
+    };
#endif

    auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});

    if(trace_level > 0)
    {
-        return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
-            ctx.finish();
-            std::cout << "Run instruction: ";
-            this->debug_print(ins);
-            auto result = check_context(f);
-            ctx.finish();
-            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
-                std::cout << "Ouput: " << result << std::endl;
-            return result;
-        });
+        return generic_eval(*this,
+                            ctx,
+                            std::move(params),
+                            with_check_context([&](auto& ins, auto f, auto&& check_context) {
+                                ctx.finish();
+                                std::cout << "Run instruction: ";
+                                this->debug_print(ins);
+                                timer t{};
+                                auto result = check_context(f);
+                                double t1   = t.record<milliseconds>();
+                                ctx.finish();
+                                double t2 = t.record<milliseconds>();
+                                std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
+                                if(trace_level > 1 and ins->name().front() != '@' and
+                                   ins->name() != "load")
+                                    std::cout << "Output: " << result << std::endl;
+                                return result;
+                            }));
    }
    else
    {
-        return generic_eval(
-            *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
+        return generic_eval(*this,
+                            ctx,
+                            std::move(params),
+                            with_check_context([&](auto&, auto f, auto&& check_context) {
+                                return check_context(f);
+                            }));
    }
}
...
@@ -478,10 +496,17 @@ double common_average(const std::vector<double>& v)
    return total / std::distance(v.begin() + n, v.end() - n);
}
std::string perf_group(const operation& op)
{
auto attr = op.attributes();
if(attr.contains("group"))
return attr.at("group").to<std::string>();
return op.name();
}
void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
{
-    using milliseconds = std::chrono::duration<double, std::milli>;
    auto& ctx = this->impl->ctx;
    // Run once by itself
    eval(params);
    ctx.finish();
...
@@ -498,21 +523,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params
    std::sort(total_vec.begin(), total_vec.end());
    std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
    // Fill the map
-    generic_eval(*this, ctx, params, [&](auto ins, auto) {
+    generic_eval(*this, ctx, params, always([&](auto ins, auto) {
        ins_vec[ins].reserve(n);
-        return argument{};
-    });
+        return argument{ins->get_shape(), nullptr};
+    }));
    // Run and time each instruction
    for(std::size_t i = 0; i < n; i++)
    {
-        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
+        generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
            argument result;
            ins_vec[ins].push_back(time<milliseconds>([&] {
                result = f();
                ctx.finish();
            }));
            return result;
-        });
+        }));
    }
    for(auto&& p : ins_vec)
        std::sort(p.second.begin(), p.second.end());
...
@@ -533,7 +559,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params
    for(auto&& p : ins_vec)
    {
        double avg = common_average(p.second);
-        op_times[p.first->name()] += avg;
+        op_times[perf_group(p.first->get_operator())] += avg;
        total_instruction_time += avg;
    }
    double calculate_overhead_time = total_time - total_instruction_time;
...
@@ -585,7 +611,7 @@ void program::debug_print(instruction_ref ins) const
{
    std::unordered_map<instruction_ref, std::string> names;
    if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
-           return (pp.second.end() == ins);
+           return is_end(pp.second.end(), ins);
       }))
    {
        std::cout << "End instruction" << std::endl;
...
@@ -641,7 +667,9 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
    auto& ctx = this->impl->ctx;
-    generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
+    generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
+        return argument{ins->get_shape(), nullptr};
+    }));
}

void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
...
@@ -741,6 +769,22 @@ void program::remove_module(const std::string& name)
               impl->modules.at(name).end(),
               [&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
           "Instruction referenced in another module");
    // if an instruction has an input outside of the current module, we need to remove
    // the instruction from its input's outputs
auto& mod = impl->modules.at(name);
for(auto ins : iterator_for(mod))
{
auto inputs = ins->inputs();
for(auto in : inputs)
{
if(not mod.has_instruction(in))
{
in->remove_output(ins);
}
}
}
    impl->modules.erase(name);
}
...
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from .backend import is_compatible, prepare, run, supports_device
...
@@ -82,9 +82,6 @@ class MIGraphXBackend(Backend):
        elif isinstance(model, migraphx.program):
            return MIGraphXBackendRep(model, cls._input_names)
        elif isinstance(model, (str, bytes)):
-            for k, v in kwargs.items():
-                if hasattr(options, k):
-                    setattr(options, k, v)
            if device is not None and not cls.supports_device(device):
                raise RuntimeError(
                    "Incompatible device expected '{0}', got '{1}'".format(
...
File mode changed from 100644 to 100755
...
@@ -325,12 +325,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
           unsigned int default_dim_value,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           bool skip_unknown_operators,
-           bool print_program_on_error) {
+           bool print_program_on_error,
+           int64_t max_loop_iterations) {
            migraphx::onnx_options options;
            options.default_dim_value      = default_dim_value;
            options.map_input_dims         = map_input_dims;
            options.skip_unknown_operators = skip_unknown_operators;
            options.print_program_on_error = print_program_on_error;
+            options.max_loop_iterations    = max_loop_iterations;
            return migraphx::parse_onnx(filename, options);
        },
        "Parse onnx file",
...
@@ -338,7 +340,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
        py::arg("default_dim_value") = 1,
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("skip_unknown_operators") = false,
-        py::arg("print_program_on_error") = false);
+        py::arg("print_program_on_error") = false,
+        py::arg("max_loop_iterations") = 10);
    m.def("parse_onnx_buffer",
          [](const std::string& onnx_buffer,
...
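The new max_loop_iterations option (default 10) caps how far a parsed ONNX Loop is expanded. The same knob is reachable from C++ through onnx_options; a short hedged sketch (the file name is made up):

    // migraphx::onnx_options options;
    // options.max_loop_iterations = 20; // allow up to 20 loop iterations
    // auto prog = migraphx::parse_onnx("model.onnx", options);
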
-#include <migraphx/float_equal.hpp>
-#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
+#include <migraphx/quantize_fp16.hpp>
+#include <migraphx/quantize_int8.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/simplify_qdq.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
-#include <migraphx/op/convert.hpp>
-#include <migraphx/op/clip.hpp>
-#include <migraphx/op/round.hpp>
-#include <migraphx/op/dot.hpp>
-#include <migraphx/op/mul.hpp>
-#include <migraphx/op/add.hpp>
-#include <migraphx/op/quant_dot.hpp>
-#include <migraphx/op/capture.hpp>
-#include <migraphx/op/convolution.hpp>
-#include <migraphx/op/quant_convolution.hpp>
-#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp>
+#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
-#include <utility>
-#include <set>
-#include <iomanip>
-#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
-#include <fstream>
+#include <migraphx/pass_manager.hpp>
+#include <set>
+#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(module& modl,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f,
float shift = 0.0f)
{
if(map_ins.count(ins) > 0)
{
return map_ins[ins];
}
if(ins->name() == "undefined")
{
return ins;
}
assert(ins->get_shape().type() == shape::float_type or
ins->get_shape().type() == shape::double_type or
ins->get_shape().type() == shape::int32_type or
ins->get_shape().type() == shape::half_type);
instruction_ref quant_ins{};
auto insert_loc = std::next(ins);
if(type == shape::int8_type)
{
auto scaled_ins = ins;
if(scale != 1.0f)
{
auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
if(shift != 0.0f)
{
auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
}
auto rounded_ins = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
auto clipped_ins =
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(
insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
}
else
{
quant_ins =
modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
}
map_ins[ins] = quant_ins;
return quant_ins;
}
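// Worked instance of the int8 mapping that insert_quant_ins used to build
// (scale and input are assumed for illustration): with scale = 0.05 and
// shift = 0, an input of 37.6f becomes
//     clip(round(37.6 * 0.05 + 0), -128, 127) = clip(2, -128, 127) = 2
// followed by a convert to int8_type; any non-int8 target type skips the
// scale/round/clip chain and gets a bare convert instead.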
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
...
@@ -119,337 +30,14 @@ instruction_ref insert_quant_ins(module& modl,
// truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
-    auto* mm = prog.get_main_module();
-    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
-    for(auto ins : iterator_for(*mm))
-    {
-        if(ins->name() == "@return")
-            break;
-
-        // all indicates every instruction is converted
+    run_passes(prog,
+               {quantize_fp16_pass{ins_names},
+                eliminate_common_subexpression{},
+                dead_code_elimination{},
+                simplify_reshapes{},
+                dead_code_elimination{},
+                simplify_qdq{},
+                dead_code_elimination{}});
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// process all inputs, if input is a fp32 or fp64, convert it
// to a fp16 by adding a convert operator.
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is a convert operator, uses its input
// as its current input
instruction_ref input_fp16{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_quant_ins(*mm, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no change for the input, go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type = mm->insert_instruction(
std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty)
{
mm->replace_instruction(ins, ins_orig_type);
}
}
mm->replace_instruction(ins, op, converted_inputs);
}
}
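// Usage sketch (assumed driver code, not from this commit): quantize_fp16 is
// applied to a parsed program before compiling it for a target; parse_onnx and
// the model file name are assumptions here.
//
//   migraphx::program p = migraphx::parse_onnx("model.onnx");
//   quantize_fp16(p, {"all"});                 // convert every instruction
//   quantize_fp16(p, {"dot", "convolution"});  // or only selected operators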
static void ins_quantize_int8(module& modl,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
{
auto orig_type = ins->get_shape().type();
auto inputs = ins->inputs();
if(ins->name() == "dot")
{
auto dot_op = any_cast<op::dot>(ins->get_operator());
float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
float new_beta = dot_op.beta;
        // We need an additional check on the rescaled alpha and beta. If both
        // abs(new_alpha) and abs(new_beta) are at least 50 (a provisional
        // threshold chosen here), the relative rounding error is small enough
        // to round them to integers for the quant_dot
float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
modl.replace_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
}
else
{
auto quant_dot = modl.insert_instruction(
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
}
}
        // either alpha or beta cannot be quantized because the relative
        // rounding error would be too large
else
{
if(converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = modl.insert_instruction(
ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
}
else
{
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
}
else
{
auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
}
else
{
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
}
}
}
}
else if(ins->name() == "convolution")
{
// Current MIOpen convolution does not support alpha and beta,
// so we need a separate multiply to adjust the output
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = modl.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
}
        // convert the quant_conv output to float type, multiply by the factor,
        // and convert back to the original type
else
{
auto float_conv = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
}
else
{
auto adjusted_conv =
modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
}
}
}
else
{
MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
}
}
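// Editor's sketch (hypothetical helper; values assumed, needs <cmath>) of the
// threshold decision in the dot branch above: with alpha = 1.0f and input
// scales of 64.0f each, new_alpha = 1/4096 ≈ 0.000244, far below 50.0f, so
// the integer quant_dot path is skipped and the result is rescaled in float.
inline bool can_use_integer_alpha(float alpha, float scale_a, float scale_b)
{
    float new_alpha = alpha / (scale_a * scale_b); // fold the scales into alpha
    return std::fabs(new_alpha) >= 50.0f;          // threshold used above
}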
// int8 quantization is different from fp16 since int8 can only represent
// values in the range -128 ~ 127. To convert a float or double to int8, we
// need a scale and a shift, so the conversion is v_int8 = fp * scale + shift.
// To simplify the changes, we treat the shift as 0.0f for now.
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < quant_params.size(); ++i)
{
auto param = quant_params.at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
// For now, we only support the int8 quantization of gemm and convolution
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
auto* mm = prog.get_main_module();
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "@return")
break;
if(not contains(ins_names, ins->name()))
{
continue;
}
// for the dot operator, there could be 2 or 3 input arguments
// if the 3rd argument is available, convert it to an int32.
std::vector<instruction_ref> converted_inputs;
// process all inputs, if input is a fp32 or fp64, convert it
// to a int8 type by adding a convert operator and replace
// the operator with the corresponding int8 version
auto inputs = ins->inputs();
std::vector<std::pair<float, float>> ins_quant_params;
for(auto input : inputs)
{
// calculate the index of each instruction to be quantized
std::size_t ins_index =
(map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++;
map_ins_index[input] = ins_index;
auto param = quant_params[map_ins_index[input]];
ins_quant_params.push_back(param);
// In general, the target_type is int8, but for the dot
// operation, if it has 3 inputs, then the last one should
// be converted to int32_type
shape::type_t quant_type = shape::int8_type;
if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back()))
{
quant_type = shape::int32_type;
}
auto s = input->get_shape();
if((s.type() == shape::float_type or s.type() == shape::double_type or
s.type() == shape::half_type or s.type() == shape::int32_type) and
s.type() != quant_type)
{
                // if the input is a convert operator, use its input
                // as the current input
instruction_ref quant_input{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == quant_type)
{
quant_input = input->inputs().front();
                    // the scale in this case is not used, so set the scale
                    // to 1.0f for this parameter
ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
}
else
{
quant_input = insert_quant_ins(
*mm, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
else
{
converted_inputs.push_back(input);
}
}
        // no change to the inputs, so go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
ins_quantize_int8(*mm, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
{
MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
}
}
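// Editor's sketch (hypothetical helper, not part of this file) of the
// per-input type rule applied in quantize_int8_impl above: for a dot with
// three inputs, A and B are quantized to int8, while the third (accumulator)
// input must be int32.
inline shape::type_t pick_quant_type(const std::string& op_name, std::size_t n_inputs, bool is_last_input)
{
    if(op_name == "dot" and n_inputs == 3 and is_last_input)
        return shape::int32_type; // C holds sums of int8 * int8 products
    return shape::int8_type;
}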
// Removed: the old quantize_int8; the capture insertion and graph rewrite now
// run as passes (see the new implementation after capture_arguments_impl below).
void quantize_int8(program& prog,
...@@ -457,87 +45,14 @@ void quantize_int8(program& prog,
                   const std::vector<parameter_map>& calibration,
                   const std::vector<std::string>& ins_names)
{
    // insert capture operator
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
cap_prog.eval(m);
}
quantize_int8_impl(prog, *int8_quant_params, ins_names);
}
// For each input argument, we need to insert a
// capture operator to compute the scale and shift
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
auto* mm = prog.get_main_module();
size_t num_quant_params = 0;
    // the int8 quantization only supports dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
    std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
    if(!std::includes(
           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
    {
        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*mm))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
instruction_ref new_ins{};
if(ins_map.count(input) > 0)
{
new_ins = ins_map[input];
}
else
{
new_ins = mm->insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
new_args.push_back(new_ins);
}
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
} }
return num_quant_params;
}
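// Editor's sketch (simplified; not the real op::capture definition): at
// evaluation time the capture operator behaves as an identity, but first hands
// the runtime arguments to a callback, which is how the quantization scales
// are measured from calibration data.
struct capture_sketch
{
    std::size_t ins_index;                                     // parameter slot to fill
    std::function<void(std::size_t, std::vector<argument>)> f; // e.g. calc_quant_params
    argument compute(const std::vector<argument>& args) const
    {
        f(ins_index, args);  // record statistics such as the max absolute value
        return args.front(); // pass the input through unchanged
    }
};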
// Removed: capture_arguments_impl; its declarations, lambda, and calibration
// loop carry over into the new pass-based quantize_int8 below.
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{
    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
        std::make_shared<std::vector<std::pair<float, float>>>();
    std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
...@@ -545,7 +60,6 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
    auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
                                                                   std::vector<argument> args) {
        std::pair<float, float> param_pair{64.0f, 0.0f};
        // scale and shift are needed only for the int8 type, and we do not
        // consider shift, so set shift to 0
        std::vector<float> vec_val;
...@@ -568,12 +82,56 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
        int8_quant_params->at(ins_index) = param_pair;
    };
    auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
    int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
    max_abs_vals->resize(num_params, 0.0f);
    return int8_quant_params;
}

// Added: the new quantize_int8. It keeps the parameter-computation lambda and
// the calibration loop, and replaces the manual graph rewriting with passes.
void quantize_int8(program& prog,
                   const target& t,
                   const std::vector<parameter_map>& calibration,
                   const std::vector<std::string>& ins_names)
{
    std::set<std::string> op_names = {"convolution", "dot"};
    std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
    if(!std::includes(
           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
    {
        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
    }
    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
        std::make_shared<std::vector<std::pair<float, float>>>();
    std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
    auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
                                                                   std::vector<argument> args) {
        std::pair<float, float> param_pair{64.0f, 0.0f};
        // we do not consider shift, so set shift to 0
        std::vector<float> vec_val;
        // ... (lambda body elided in the diff view) ...
        int8_quant_params->at(ins_index) = param_pair;
    };
    // pass to add capture argument op
    std::size_t param_num = 0;
    run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
    int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
    max_abs_vals->resize(param_num, 0.0f);
    // use the calibration data to compute the quantization scale
    auto capture_prog = prog;
    capture_prog.compile(t);
    // use all calibration data to run the program to calculate the
    // quantization scale and shift
    for(auto&& arg : calibration)
    {
        parameter_map m;
        for(auto&& x : capture_prog.get_parameter_shapes())
        {
            if(arg.count(x.first) > 0)
            {
                assert(x.second == arg.at(x.first).get_shape());
                m[x.first] = t.copy_to(arg.at(x.first));
            }
            else
            {
                m[x.first] = t.allocate(x.second);
            }
        }
        capture_prog.eval(m);
    }
    // print the quantization parameters of the main module only
    if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
    {
        for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
        {
            auto param = int8_quant_params->at(i);
            std::cout << "ins_index = " << i << ", scale = " << param.first
                      << ", shift = " << param.second << std::endl;
        }
        std::cout << std::endl;
    }
    run_passes(prog,
               {quantize_int8_pass{ins_names, *int8_quant_params},
                eliminate_common_subexpression{},
                dead_code_elimination{},
                simplify_reshapes{},
                dead_code_elimination{},
                simplify_qdq{},
                dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
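// Usage sketch (assumed driver code): int8 quantization needs a target and
// representative calibration inputs; the model file name and the
// generate_argument fill-in are assumptions here.
//
//   migraphx::program p = migraphx::parse_onnx("model.onnx");
//   migraphx::parameter_map inputs;
//   for(auto&& pr : p.get_parameter_shapes())
//       inputs[pr.first] = migraphx::generate_argument(pr.second);
//   quantize_int8(p, t, {inputs}, {"dot", "convolution"});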