Commit 00d5d880 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
@@ -7,6 +7,12 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+template <class T>
+auto equal_to(const T& x)
+{
+    return [&](const T& y) { return std::equal_to<T>{}(x, y); };
+}
 instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
     : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
 {
@@ -133,8 +139,13 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
 bool operator==(const instruction& x, const instruction& y)
 {
-    if(std::tie(x.result, x.op, x.arguments, x.module_args) !=
-       std::tie(y.result, y.op, y.arguments, y.module_args))
+    if(not std::equal(x.arguments.begin(),
+                      x.arguments.end(),
+                      y.arguments.begin(),
+                      y.arguments.end(),
+                      std::equal_to<instruction_ref>{}))
+        return false;
+    if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
         return false;
     if(x.name() == "@literal")
         return x.lit == y.lit;
@@ -151,7 +162,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref);
 void instruction::add_output(instruction_ref ins)
 {
-    if(std::find(output.begin(), output.end(), ins) == output.end())
+    if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end())
         output.push_back(ins);
 }
@@ -256,8 +267,8 @@ void instruction::replace(std::vector<instruction_ref> args, std::vector<module_
 void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
 {
-    assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
-    std::replace(arguments.begin(), arguments.end(), old, new_ins);
+    assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
+    std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins);
     old->remove_output(*this);
 }
......
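The equal_to helper added at the top of instruction.cpp turns a value into a unary predicate built on std::equal_to, so find_if/replace_if/any_of can be used in place of the plain value overloads of find/replace/any_of. A minimal, self-contained sketch of the same pattern on a plain std::vector (illustrative only, not MIGraphX types):

    // Stand-alone illustration of the equal_to predicate-factory pattern.
    #include <algorithm>
    #include <functional>
    #include <iostream>
    #include <vector>

    template <class T>
    auto equal_to(const T& x)
    {
        // Capture x and defer to std::equal_to<T>, so the caller controls
        // exactly which comparison is used.
        return [&](const T& y) { return std::equal_to<T>{}(x, y); };
    }

    int main()
    {
        std::vector<int> v = {1, 2, 3, 4};
        int target         = 3;
        auto it = std::find_if(v.begin(), v.end(), equal_to(target));
        std::cout << (it != v.end() ? "found" : "missing") << "\n"; // prints "found"
    }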
@@ -6,6 +6,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/time.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/iterator.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_target.hpp>
@@ -30,7 +31,7 @@ struct module_impl
     bool contains(instruction_ref ins) const
     {
-        if(ins == instructions.end())
+        if(is_end(ins, instructions.end()))
            return false;
        return instruction_set.count(std::addressof(*ins)) > 0;
     }
@@ -149,14 +150,7 @@ void module::assign(const module& m)
         }
         else
         {
-            if(module_args.empty())
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs);
-            }
-            else
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
-            }
+            copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
         }
     }
@@ -498,7 +492,7 @@ void module::debug_print() const { std::cout << *this << std::endl; }
 void module::debug_print(instruction_ref ins,
                          std::unordered_map<instruction_ref, std::string>& names) const
 {
-    if(ins == this->end())
+    if(is_end(ins, this->end()))
     {
         std::cout << "End instruction" << std::endl;
         return;
......
@@ -117,14 +117,43 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     return result;
 }
+auto tune_pad_attribute(const value& val)
+{
+    std::vector<size_t> vec_attrs = val.to_vector<size_t>();
+    std::vector<size_t> result(vec_attrs.begin(), vec_attrs.end());
+    std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result));
+    return result;
+}
 bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
 {
     bool tuned = false;
     auto attrs = op.attributes();
     auto val   = op.to_value();
+    if(attrs.contains("normalize_padding"))
+    {
+        auto padding      = val.at(attrs.at("normalize_padding").to<std::string>());
+        auto padding_size = padding.size();
+        // for now, assume the dimensions to pad start at dim 2
+        auto padding_start = 2;
+        if(padding_size == 2 * (lens.size() - padding_start))
+            tuned = true;
+        else if(padding_size != (lens.size() - padding_start))
+            MIGRAPHX_THROW("inconsistent padding size");
+        else
+        {
+            auto result    = tune_pad_attribute(padding);
+            val["padding"] = result;
+            op.from_value(val);
+            tuned = true;
+        }
+    }
     if(!attrs.contains("normalize_axes"))
     {
-        return false;
+        return tuned;
     }
     auto attr_v = attrs.at("normalize_axes").without_key();
......
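The new normalize_padding branch expands a padding attribute that carries one entry per spatial dimension into the symmetric 2*k form by appending a second copy, which is what tune_pad_attribute produces; an attribute that already has 2*k entries is left alone. A small stand-alone sketch of that expansion using plain std::vector instead of migraphx::value:

    // Sketch of the padding expansion performed by tune_pad_attribute.
    #include <cassert>
    #include <iterator>
    #include <vector>

    std::vector<std::size_t> expand_padding(const std::vector<std::size_t>& pads)
    {
        std::vector<std::size_t> result(pads.begin(), pads.end());
        // Append a second copy: {p0, p1} -> {p0, p1, p0, p1}
        std::copy(pads.begin(), pads.end(), std::back_inserter(result));
        return result;
    }

    int main()
    {
        assert((expand_padding({1, 2}) == std::vector<std::size_t>{1, 2, 1, 2}));
    }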
@@ -7,7 +7,7 @@ namespace onnx {
 void recalc_conv_attributes(value& v, size_t kdims)
 {
-    if(v["padding"].size() != kdims)
+    if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
     {
         v["padding"].resize(kdims);
         std::fill_n(v["padding"].begin(), kdims, 0);
......
@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
     auto left_pad_it  = padding.begin();
     auto right_pad_it = left_pad_it + pad_ndims;
-    if(is_asym_padding(padding) or count_include_pad == 1)
+    if(count_include_pad == 1)
     {
         std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
         // add left pads
@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
         // add right pads
         asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
         ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
-    }
-    else
-    {
-        v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
+        std::vector<size_t> new_padding(padding.size());
+        // subtract asym padding originally found from parsing the operator
+        std::transform(padding.begin(),
+                       left_pad_it,
+                       asym_pads.begin() + 2,
+                       new_padding.begin(),
+                       std::minus<size_t>());
+        std::transform(right_pad_it,
+                       padding.end(),
+                       asym_pads.begin() + pad_ndims + 4,
+                       new_padding.begin() + pad_ndims,
+                       std::minus<size_t>());
+        v["padding"] = new_padding;
     }
 }
......
@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
                 values["padding_mode"] = to_value(op::padding_mode_t::same);
             }
         }
-        check_asym_padding(info, l0, padding, values);
+        values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
         if(contains(info.attributes, "group"))
         {
......
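With check_asym_padding no longer invoked here, the convolution parser stores the full pad vector (possibly asymmetric, of length 2*kdims) directly in the operator's padding attribute, which recalc_conv_attributes and normalize_attributes now accept. A hedged sketch of constructing such an operator; the literal values are made up for illustration and this is not code from the commit:

    // Hypothetical example of a 2-D convolution with asymmetric padding.
    #include <migraphx/make_op.hpp>

    migraphx::operation make_asym_conv()
    {
        // begin pads {0, 1}, end pads {2, 3}
        return migraphx::make_op(
            "convolution",
            {{"padding", {0, 1, 2, 3}}, {"stride", {1, 1}}, {"dilation", {1, 1}}});
    }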
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/onnx/onnx_parser.hpp>
 #include <migraphx/onnx/checks.hpp>
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
             MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
         }
-        migraphx::argument cond_arg = args.front()->eval();
-        // cond is not constant, need to create sub_modules
-        if(cond_arg.empty())
-        {
-            std::string then_name = info.name + "_if";
-            module_ref then_mdl   = parser.prog.create_module(then_name);
-            std::string else_name = info.name + "_else";
-            module_ref else_mdl   = parser.prog.create_module(else_name);
-            // parse the then sub_graph
-            parser.parse_graph(then_mdl, then_graph);
-            // parse_the else sub_graph
-            parser.parse_graph(else_mdl, else_graph);
-            auto then_out_shapes = then_mdl->get_output_shapes();
-            auto else_out_shapes = else_mdl->get_output_shapes();
-            if(not std::equal(then_out_shapes.begin(),
-                              then_out_shapes.end(),
-                              else_out_shapes.begin(),
-                              else_out_shapes.end()))
-            {
-                MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
-            }
-            auto ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
-            return {ret};
-        }
-        else
-        {
-            auto* mod = info.mod;
-            // then branch
-            if(cond_arg.at<bool>())
-            {
-                parser.parse_graph(mod, then_graph);
-            }
-            // else branch
-            else
-            {
-                parser.parse_graph(mod, else_graph);
-            }
-            // inputs of the return instruction are that of the output of the
-            // if instruction
-            instruction_ref ret_ins = std::prev(mod->end());
-            auto outputs            = ret_ins->inputs();
-            assert(ret_ins->name() == "@return");
-            mod->remove_instruction(ret_ins);
-            return outputs;
-        }
+        std::string then_name = info.name + "_if";
+        module_ref then_mdl   = parser.prog.create_module(then_name);
+        std::string else_name = info.name + "_else";
+        module_ref else_mdl   = parser.prog.create_module(else_name);
+        // parse the then sub_graph
+        parser.parse_graph(then_mdl, then_graph);
+        // parse_the else sub_graph
+        parser.parse_graph(else_mdl, else_graph);
+        auto then_out_shapes = then_mdl->get_output_shapes();
+        auto else_out_shapes = else_mdl->get_output_shapes();
+        if(not std::equal(then_out_shapes.begin(),
+                          then_out_shapes.end(),
+                          else_out_shapes.begin(),
+                          else_out_shapes.end()))
+        {
+            MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
+        }
+        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
+        auto out_s  = if_ret->get_shape();
+        assert(out_s.type() == shape::tuple_type);
+        const auto& vec_shapes = out_s.sub_shapes();
+        std::vector<instruction_ref> out_inss;
+        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
+        {
+            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
+            out_inss.push_back(ret);
+        }
+        return out_inss;
     }
 };
......
@@ -133,18 +133,11 @@ struct parse_pooling : op_parser<parse_pooling>
                            slice_end.begin(),
                            [](auto i, auto j) { return i + j; });
         }
+        values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
         check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
-        in_lens = l0->get_shape().lens();
-        for(size_t i = 0; i < kdims; i++)
-        {
-            if(values["lengths"][i].to<int64_t>() >
-               in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
-            {
-                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
-            }
-        }
         op.from_value(values);
         auto l1 = info.add_instruction(op, l0);
         if(!slice_start.empty())
         {
......
@@ -55,7 +55,7 @@ const auto& get_original_idx_op(const std::string& mode)
          }},
         {"align_corners",
          [=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
-             return 1.0 * idx * (l_in - 1.0) / (l_out - 1.0);
+             return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0));
          }},
         {"asymmetric",
          [=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
@@ -71,6 +71,96 @@ const auto& get_original_idx_op(const std::string& mode)
     return idx_ops.at(mode);
 }
+static std::vector<int>
+calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
+                     int i_dim,
+                     const std::vector<std::vector<std::size_t>>& vec_dims,
+                     const shape& in_s)
+{
+    if(i_dim == vvv_ind.size())
+    {
+        std::vector<int> vec_ind;
+        vec_ind.resize(vec_dims.size());
+        std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
+            return static_cast<int>(in_s.index(idx));
+        });
+        return vec_ind;
+    }
+    const auto& vv_ind = vvv_ind[i_dim];
+    const auto& vv_lo  = vv_ind.at(0);
+    std::vector<std::vector<std::size_t>> vec_dims1;
+    for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
+    {
+        std::transform(vv_lo.begin(),
+                       vv_lo.end(),
+                       vec_dims.begin() + start,
+                       std::back_inserter(vec_dims1),
+                       [](auto i, auto dim) {
+                           dim.push_back(i);
+                           return dim;
+                       });
+    }
+    const auto& vv_hi = vv_ind.at(1);
+    for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
+    {
+        std::transform(vv_hi.begin(),
+                       vv_hi.end(),
+                       vec_dims.begin() + start,
+                       std::back_inserter(vec_dims1),
+                       [](auto i, auto dim) {
+                           dim.push_back(i);
+                           return dim;
+                       });
+    }
+    return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s);
+}
+static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
+{
+    std::string coord_trans_mode = "half_pixel";
+    if(contains(attr, "coordinate_transformation_mode"))
+    {
+        coord_trans_mode = attr.at("coordinate_transformation_mode").s();
+        // does not support transformation mode "tf_crop_and_resize"
+        if(coord_trans_mode == "tf_crop_and_resize")
+        {
+            MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
+        }
+    }
+    return coord_trans_mode;
+}
+static std::string get_mode(const onnx_parser::attribute_map& attr)
+{
+    std::string mode = "nearest";
+    if(contains(attr, "mode"))
+    {
+        mode = attr.at("mode").s();
+        if(mode != "nearest" and mode != "linear")
+        {
+            MIGRAPHX_THROW("PARSE_RESIZE: only nearest and linear modes are supported!");
+        }
+    }
+    return mode;
+}
+static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
+{
+    std::string nearest_mode = "round_prefer_floor";
+    if(contains(attr, "nearest_mode"))
+    {
+        nearest_mode = attr.at("nearest_mode").s();
+    }
+    return nearest_mode;
+}
 struct parse_resize : op_parser<parse_resize>
 {
     std::vector<op_desc> operators() const { return {{"Resize"}}; }
@@ -80,42 +170,20 @@ struct parse_resize : op_parser<parse_resize>
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
     {
-        std::string coord_trans_mode = "half_pixel";
-        if(contains(info.attributes, "coordinate_transformation_mode"))
-        {
-            coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s();
-            // does not support transformation mode "tf_crop_and_resize"
-            if(coord_trans_mode == "tf_crop_and_resize")
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
-            }
-        }
-        // mode: only nearest mode is supported for now
-        if(contains(info.attributes, "mode"))
-        {
-            auto mode = info.attributes.at("mode").s();
-            if(mode != "nearest")
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: only nearest mode is supported!");
-            }
-        }
-        // nearest mode
-        std::string nearest_mode = "round_prefer_floor";
-        if(contains(info.attributes, "nearest_mode"))
-        {
-            nearest_mode = info.attributes.at("nearest_mode").s();
-        }
-        // check exclude_outside, only support 0
-        if(contains(info.attributes, "exclude_outside"))
-        {
-            int exclude_outside = info.attributes.at("exclude_outside").i();
-            if(exclude_outside == 1)
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
-            }
-        }
+        // coord transform mode
+        std::string coord_trans_mode = get_coord_trans_mode(info.attributes);
+        // mode: only nearest and linear modes are supported for now
+        std::string mode = get_mode(info.attributes);
+        // nearest mode
+        std::string nearest_mode = get_nearest_mode(info.attributes);
+        // check exclude_outside, only support 0
+        if(contains(info.attributes, "exclude_outside") and
+           info.attributes.at("exclude_outside").i() == 1)
+        {
+            MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
+        }
         // input data shape info
@@ -128,74 +196,164 @@ struct parse_resize : op_parser<parse_resize>
         // scale
         std::vector<double> vec_scale;
-        // output size is specified in input, so use it as output size
-        if(args.size() == 4 and args.back()->name() != "undefined")
-        {
-            auto arg_out_s = args[3]->eval();
-            check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
-            arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
-            if(out_lens.size() != in_lens.size())
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
-            }
-            // compute the scale
-            vec_scale.resize(in_lens.size());
-            std::transform(in_lens.begin(),
-                           in_lens.end(),
-                           out_lens.begin(),
-                           vec_scale.begin(),
-                           [](auto iss, auto oss) { return 1.0 * oss / iss; });
-        }
-        // need to compute the output lens from input
-        else
-        {
-            auto arg_scale = args[2]->eval();
-            check_arg_empty(arg_scale, "PARSE_RESIZE: dynamic input scale is not supported!");
-            arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
-            if(in_lens.size() != vec_scale.size())
-            {
-                MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
-            }
-            std::transform(
-                in_lens.begin(),
-                in_lens.end(),
-                vec_scale.begin(),
-                out_lens.begin(),
-                [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
-        }
+        for(const auto& arg : args)
+        {
+            if(arg->name() == "undefined" or arg == args.front())
+            {
+                continue;
+            }
+            // skipped empty input
+            auto lens = arg->get_shape().lens();
+            if(lens.empty())
+            {
+                continue;
+            }
+            auto type = arg->get_shape().type();
+            // output size
+            if(type == shape::int64_type)
+            {
+                auto arg_out_s = arg->eval();
+                check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
+                arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
+                if(out_lens.size() != in_lens.size())
+                {
+                    MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
+                }
+                // compute the scale
+                vec_scale.resize(in_lens.size());
+                std::transform(in_lens.begin(),
+                               in_lens.end(),
+                               out_lens.begin(),
+                               vec_scale.begin(),
+                               [](auto iss, auto oss) { return 1.0 * oss / iss; });
+            }
+            else
+            {
+                // scale input
+                if(lens[0] == in_lens.size())
+                {
+                    auto arg_scale = arg->eval();
+                    check_arg_empty(arg_scale,
+                                    "PARSE_RESIZE: dynamic input scale is not supported!");
+                    arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
+                    if(in_lens.size() != vec_scale.size())
+                    {
+                        MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
+                    }
+                    std::transform(in_lens.begin(),
+                                   in_lens.end(),
+                                   vec_scale.begin(),
+                                   out_lens.begin(),
+                                   [&](auto idx, auto scale) {
+                                       return static_cast<std::size_t>(idx * scale);
+                                   });
+                }
+            }
+        }
         shape out_s{in_s.type(), out_lens};
-        std::vector<int> ind(out_s.elements());
-        // map out_idx to in_idx
-        auto nearest_op = get_nearest_op(nearest_mode);
-        auto idx_op     = get_original_idx_op(coord_trans_mode);
-        shape_for_each(out_s, [&](auto idx) {
-            auto in_idx = idx;
-            for(auto ii = 0; ii < in_lens.size(); ++ii)
-            {
-                auto idx_val = idx_op(in_lens[ii], out_lens[ii], in_idx[ii], vec_scale[ii]);
-                in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
-            }
-            ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
-        });
-        // reshape input to one-dimension
-        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
-        shape ind_s{shape::int32_type, out_lens};
-        auto arg_cont = info.make_contiguous(args[0]);
-        auto rsp      = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), arg_cont);
-        auto ins_ind  = info.add_literal(literal(ind_s, ind));
-        return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+        std::size_t out_elements = out_s.elements();
+        auto idx_op              = get_original_idx_op(coord_trans_mode);
+        // reshape input to one-dimension
+        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
+        args[0]                       = info.make_contiguous(args[0]);
+        auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
+        if(mode == "nearest")
+        {
+            std::vector<int> ind(out_elements);
+            // map out_idx to in_idx
+            auto nearest_op = get_nearest_op(nearest_mode);
+            shape_for_each(out_s, [&](auto idx) {
+                auto in_idx = idx;
+                for(auto ii = 0; ii < in_lens.size(); ++ii)
+                {
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
+                }
+                ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
+            });
+            shape ind_s{shape::int32_type, out_lens};
+            auto ins_ind = info.add_literal(literal(ind_s, ind));
+            return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+        }
+        // linear mode
+        else
+        {
+            auto nearest_floor = get_nearest_op("floor");
+            auto nearest_ceil  = get_nearest_op("ceil");
+            // get the number of dimensions
+            std::size_t n_dim = out_lens.size();
+            std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
+            std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
+            std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
+            shape_for_each(out_s, [&](auto idx) {
+                auto in_idx  = idx;
+                auto out_idx = out_s.index(idx);
+                for(auto ii = 0; ii < in_lens.size(); ++ii)
+                {
+                    auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]);
+                    vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
+                    vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
+                    delta[ii][out_idx]      = idx_val - vvv_ind[ii][0][out_idx];
+                }
+            });
+            std::vector<std::vector<std::size_t>> vec_dims(out_elements);
+            auto ind      = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s);
+            auto ind_lens = out_lens;
+            ind_lens[0] *= (std::size_t{1} << n_dim);
+            shape ind_s{shape::int32_type, ind_lens};
+            auto ins_ind = info.add_literal(literal(ind_s, ind));
+            auto data    = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
+            auto dim_lens = out_lens;
+            dim_lens[0] *= (std::size_t{1} << (n_dim - 1));
+            for(std::size_t i = 0; i < n_dim; ++i)
+            {
+                shape dim_s{shape::float_type, dim_lens};
+                const auto& dim_delta = delta[n_dim - i - 1];
+                std::vector<float> delta_data;
+                for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)
+                {
+                    delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
+                }
+                auto ins_delta = info.add_literal(dim_s, delta_data);
+                // slice the data
+                int64_t slc_stride = static_cast<int64_t>(dim_lens[0]);
+                auto low           = info.add_instruction(
+                    make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {slc_stride}}}),
+                    data);
+                auto hi = info.add_instruction(
+                    make_op("slice",
+                            {{"axes", {0}}, {"starts", {slc_stride}}, {"ends", {2 * slc_stride}}}),
+                    data);
+                auto diff = info.add_instruction(make_op("sub"), hi, low);
+                auto ddf  = info.add_instruction(make_op("mul"), diff, ins_delta);
+                data      = info.add_instruction(make_op("add"), ddf, low);
+                dim_lens[0] /= 2;
+            }
+            return data;
+        }
     }
 };
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
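The new linear branch builds multi-dimensional linear interpolation as a graph: it gathers the floor/ceil neighbours of every output coordinate and then repeatedly blends pairs with low + delta * (hi - low), one axis at a time. A stand-alone 1-D sketch of the same arithmetic, assuming the half_pixel transform; this is illustrative only and not the parser code:

    // 1-D linear resize: gather floor/ceil neighbours, then lerp with delta.
    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    std::vector<float> resize_linear_1d(const std::vector<float>& in, std::size_t out_len)
    {
        std::vector<float> out(out_len);
        const double scale = static_cast<double>(out_len) / in.size();
        for(std::size_t i = 0; i < out_len; ++i)
        {
            // half_pixel coordinate transform, clamped to the valid range
            double x = (i + 0.5) / scale - 0.5;
            if(x < 0)
                x = 0;
            if(x > in.size() - 1.0)
                x = in.size() - 1.0;
            auto lo  = static_cast<std::size_t>(std::floor(x)); // "nearest_floor"
            auto hi  = static_cast<std::size_t>(std::ceil(x));  // "nearest_ceil"
            double d = x - lo;                                  // "delta"
            out[i]   = static_cast<float>(in[lo] + d * (in[hi] - in[lo]));
        }
        return out;
    }

    int main()
    {
        for(float v : resize_linear_1d({0.0f, 10.0f}, 4))
            std::cout << v << ' '; // prints: 0 2.5 7.5 10
    }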
@@ -20,18 +20,15 @@ struct parse_slice : op_parser<parse_slice>
     {
         op::slice op;
+        std::vector<int64_t> steps;
         // slice can have up to 5 inputs, we first check the 5th one
         // to decide whether MIGRAPHX can handle this slice
         if(args.size() == 5)
         {
             migraphx::argument step_arg = args.back()->eval();
             check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
-            std::vector<int> steps;
             step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
-            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
-            {
-                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
-            }
         }
         if(args.size() >= 4)
@@ -77,7 +74,38 @@ struct parse_slice : op_parser<parse_slice>
             op.axes = axes;
         }
-        return info.add_instruction(op, args[0]);
+        std::vector<int64_t> raxes;
+        assert(steps.empty() or steps.size() == op.axes.size());
+        assert(op.axes.size() == op.starts.size());
+        assert(op.axes.size() == op.ends.size());
+        for(auto i : range(steps.size()))
+        {
+            if(steps[i] >= 0)
+                continue;
+            op.starts[i] += 1;
+            if(op.starts[i] == 0)
+                op.starts[i] = INT_MAX;
+            op.ends[i] += 1;
+            raxes.push_back(op.axes[i]);
+            std::swap(op.starts[i], op.ends[i]);
+        }
+        auto ins = info.add_instruction(op, args[0]);
+        if(not raxes.empty())
+            ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
+        if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
+        {
+            std::vector<int64_t> nsteps;
+            std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
+                return std::abs(s);
+            });
+            return ins = info.add_instruction(
+                       make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
+        }
+        else
+            return ins;
     }
 };
......
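Rather than rejecting steps other than 1, the slice parser now rewrites a negative step as slice plus reverse (plus a step op when |step| > 1), adjusting starts and ends as shown above. A small stand-alone check of that equivalence on a plain vector; this is illustrative and does not use the MIGraphX ops themselves:

    // Slice with step -2 equals: slice the shifted range, reverse, take every 2nd.
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    std::vector<int> slice_step(const std::vector<int>& v, int start, int end, int step)
    {
        // Emulate ONNX slice semantics for a negative step on one axis.
        std::vector<int> out;
        for(int i = start; i > end; i += step)
            out.push_back(v[i]);
        return out;
    }

    int main()
    {
        std::vector<int> v = {0, 1, 2, 3, 4, 5, 6, 7};
        // ONNX: starts=7, ends=0, steps=-2  ->  {7, 5, 3, 1}
        auto direct = slice_step(v, 7, 0, -2);

        // Rewrite used by the parser: starts/ends become 1 and 8 after the
        // +1/swap adjustment, then reverse, then keep every 2nd element.
        std::vector<int> sliced(v.begin() + 1, v.begin() + 8);
        std::reverse(sliced.begin(), sliced.end());
        std::vector<int> stepped;
        for(std::size_t i = 0; i < sliced.size(); i += 2)
            stepped.push_back(sliced[i]);

        assert(direct == stepped);
    }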
@@ -9,6 +9,7 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/register_target.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/iterator.hpp>
 #include <migraphx/algorithm.hpp>
 #include <migraphx/output_iterator.hpp>
 #include <migraphx/make_op.hpp>
@@ -242,20 +243,10 @@ std::vector<argument> generic_eval(const module* mod,
                 return generic_eval(smod, ctx, inputs, results, trace);
             };
-            if(not mod_args.empty())
-            {
-                results.emplace(ins, trace(ins, [&] {
-                                     return ins->normalized_operator().compute(
-                                         values, mod_args, module_eval);
-                                 }));
-            }
-            else
-            {
-                results.emplace(ins, trace(ins, [&] {
-                                     return ins->normalized_operator().compute(
-                                         ctx, ins->get_shape(), values);
-                                 }));
-            }
+            results.emplace(ins, trace(ins, [&] {
+                                 return ins->normalized_operator().compute(
+                                     ctx, ins->get_shape(), values, mod_args, module_eval);
+                             }));
         }
         assert(results.find(ins) != results.end());
     }
@@ -484,6 +475,14 @@ double common_average(const std::vector<double>& v)
     return total / std::distance(v.begin() + n, v.end() - n);
 }
+std::string perf_group(const operation& op)
+{
+    auto attr = op.attributes();
+    if(attr.contains("group"))
+        return attr.at("group").to<std::string>();
+    return op.name();
+}
 void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
 {
     auto& ctx = this->impl->ctx;
@@ -538,7 +537,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     for(auto&& p : ins_vec)
     {
         double avg = common_average(p.second);
-        op_times[p.first->name()] += avg;
+        op_times[perf_group(p.first->get_operator())] += avg;
         total_instruction_time += avg;
     }
     double calculate_overhead_time = total_time - total_instruction_time;
@@ -590,7 +589,7 @@ void program::debug_print(instruction_ref ins) const
 {
     std::unordered_map<instruction_ref, std::string> names;
     if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
-           return (pp.second.end() == ins);
+           return is_end(pp.second.end(), ins);
       }))
     {
        std::cout << "End instruction" << std::endl;
......
@@ -2,6 +2,7 @@
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/iterator.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/functional.hpp>
@@ -122,11 +123,11 @@ struct stream_info
         std::unordered_map<instruction_ref, std::deque<partition>> partitions;
         partitions.reserve(weights.size());
         fix([&](auto self, auto ins, auto& part) {
-            assert(ins != p.end());
-            if(contains(partitions, ins))
-                return;
+            assert(not is_end(ins, p.end()));
             if(not p.has_instruction(ins))
                 return;
+            if(contains(partitions, ins))
+                return;
             // Add an entry so we know the instruction was visited
             partitions[ins];
......
@@ -149,7 +149,8 @@ struct find_mul_slice_conv
         assert(ins->get_shape().lens() == slice1->get_shape().lens());
         p.replace_instruction(ins, slice1);
         // TODO: Check each slice doesn't overlap and that it occurs after slice_ins
-        for(auto output : conv_ins->outputs())
+        auto outputs = conv_ins->outputs();
+        for(auto output : outputs)
             if(output != slice_ins)
                 instruction::replace_argument(output, conv_ins, new_conv);
     }
@@ -554,7 +555,8 @@ struct find_splits
             auto split = i->inputs()[split_idx];
             assert(split->name() == "slice");
             // Insert contiguous for reshapes
-            for(auto output : i->outputs())
+            auto outputs = i->outputs();
+            for(auto output : outputs)
             {
                 if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
                     continue;
......
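Both loops now iterate over a copy of outputs(), presumably because instruction::replace_argument edits the very output lists being walked. A minimal illustration of the hazard and the snapshot fix, on a plain std::vector:

    // Mutating a container while range-iterating it invalidates the loop,
    // so iterate a snapshot instead (like `auto outputs = conv_ins->outputs();`).
    #include <vector>

    void mutate(std::vector<int>& v) { v.push_back(0); } // stand-in for replace_argument

    void safe_walk(std::vector<int>& v)
    {
        auto outputs = v; // snapshot of the list we are about to modify
        for(int x : outputs)
        {
            (void)x;   // use the element as needed
            mutate(v); // safe: the live container can change freely
        }
    }

    int main()
    {
        std::vector<int> v = {1, 2, 3};
        safe_walk(v);
    }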
@@ -14,6 +14,8 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
         return pack_join(self.reflect_base(self, f), pack(f(self.algo, "algo")));
     }
+    std::string group() const { return this->name() + "::" + algo; }
     std::string name() const { return "dnnl::binary"; }
     shape compute_shape(std::vector<shape> inputs) const
......
@@ -40,6 +40,9 @@ struct dnnl_convolution
         auto dilation = op.dilation;
         std::transform(
             dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
+        auto kdims = op.kdims();
+        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
+        std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
         return {dnnl::prop_kind::forward_inference,
                 dnnl::algorithm::convolution_auto,
                 m.at(DNNL_ARG_SRC),
@@ -47,8 +50,8 @@ struct dnnl_convolution
                 m.at(DNNL_ARG_DST),
                 to_dnnl_dims(op.stride),
                 to_dnnl_dims(dilation),
-                to_dnnl_dims(op.padding),
-                to_dnnl_dims(op.padding)};
+                to_dnnl_dims(padding_l),
+                to_dnnl_dims(padding_r)};
     }
 };
......
@@ -17,6 +17,8 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
                          pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta")));
     }
+    std::string group() const { return this->name() + "::" + algo; }
     std::string name() const { return "dnnl::eltwise"; }
     shape compute_shape(std::vector<shape> inputs) const
......
@@ -74,6 +74,25 @@ struct dnnl_op : auto_register_op<Derived>
         return reflect_base(self, f);
     }
+    std::string group() const
+    {
+        const auto& self = static_cast<const Derived&>(*this);
+        return self.name();
+    }
+    value attributes() const
+    {
+        std::vector<std::string> names;
+        std::transform(post_ops.begin(), post_ops.end(), std::back_inserter(names), [](auto&& op) {
+            return op.algo;
+        });
+        const auto& self = static_cast<const Derived&>(*this);
+        auto g           = self.group();
+        if(not names.empty())
+            g += "<" + join_strings(names, ",") + ">";
+        return {{"group", g}};
+    }
     std::size_t get_extra_post_op_args() const
     {
         return std::count_if(post_ops.begin(), post_ops.end(), [](const auto& po) {
......
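group() plus the new attributes() let each DNNL op advertise a perf-report bucket such as name::algo, with fused post-op names appended in angle brackets; perf_report in program.cpp then aggregates timings by this string via perf_group. A self-contained sketch of the string composition only; the algorithm names below are made up for illustration:

    // Compose a group string: base name plus fused post-ops in angle brackets.
    #include <iostream>
    #include <numeric>
    #include <string>
    #include <vector>

    std::string make_group(const std::string& base, const std::vector<std::string>& post_ops)
    {
        if(post_ops.empty())
            return base;
        std::string joined = std::accumulate(
            std::next(post_ops.begin()), post_ops.end(), post_ops.front(),
            [](std::string acc, const std::string& s) { return std::move(acc) + "," + s; });
        return base + "<" + joined + ">";
    }

    int main()
    {
        std::cout << make_group("dnnl::convolution", {"eltwise_relu"}) << "\n";
        // prints: dnnl::convolution<eltwise_relu>
    }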
@@ -66,7 +66,10 @@ struct cpu_im2col
     }
     static std::string name() { return "cpu::im2col"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        return op.normalize_compute_shape(inputs);
+    }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
......
@@ -63,7 +63,7 @@ struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
     shape compute_shape(std::vector<shape> inputs) const
     {
         inputs.pop_back();
-        return op.compute_shape(inputs);
+        return op.normalize_compute_shape(inputs);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
@@ -129,15 +129,18 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
     dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
     {
         auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
+        auto kdims = op.kdims();
+        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
+        std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
         return {dnnl::prop_kind::forward_inference,
                 algo,
                 m.at(DNNL_ARG_SRC),
                 m.at(DNNL_ARG_DST),
                 to_dnnl_dims(op.stride),
                 to_dnnl_dims(op.lengths),
-                to_dnnl_dims(op.padding),
-                to_dnnl_dims(op.padding)};
+                to_dnnl_dims(padding_l),
+                to_dnnl_dims(padding_r)};
     }
 };
......
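Since the padding attribute may now carry 2*kdims values, the DNNL convolution and pooling descriptors split it into a begin half and an end half rather than passing the same vector twice. A tiny stand-alone sketch of the split, with illustrative values:

    // Split a 2*kdims padding vector into left (begin) and right (end) halves.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> padding = {0, 1, 2, 3}; // {h_begin, w_begin, h_end, w_end}
        std::size_t kdims = 2;
        std::vector<std::size_t> padding_l(padding.begin(), padding.begin() + kdims);
        std::vector<std::size_t> padding_r(padding.begin() + kdims, padding.end());
        assert((padding_l == std::vector<std::size_t>{0, 1}));
        assert((padding_r == std::vector<std::size_t>{2, 3}));
    }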
@@ -68,6 +68,7 @@ add_library(migraphx_device
     device/reduce_sum.cpp
     device/reduce_prod.cpp
     device/relu.cpp
+    device/reverse.cpp
     device/rnn_variable_seq_lens.cpp
     device/round.cpp
     device/rsqrt.cpp
@@ -140,6 +141,7 @@ add_library(migraphx_gpu
     pooling.cpp
     preallocate_param.cpp
     quant_convolution.cpp
+    reverse.cpp
     rnn_variable_seq_lens.cpp
     rocblas.cpp
     softmax.cpp
@@ -197,6 +199,7 @@ register_migraphx_gpu_ops(hip_
     reduce_prod
     reduce_sum
     relu
+    reverse
     round
     rsqrt
     sigmoid
......