Commit f7079e51 authored by Paul

Merge

parents 79eac1b8 f6e22d56
@@ -94,7 +94,7 @@ struct onnx_parser
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
-    shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
+    shape::dynamic_dimension default_dyn_dim_value = {1, 1};
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
    std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
     bool use_dyn_output = false;
......
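Note: the third field of shape::dynamic_dimension (a single opt value) is replaced throughout this commit by a set of optimal sizes, so the default collapses to {1, 1}. A minimal sketch of the new forms (assumed usage, based on the constructors in the shape.cpp hunks below):

    // Sketch, not part of this commit:
    shape::dynamic_dimension fixed{1, 1};               // fixed dim, no optimals
    shape::dynamic_dimension batch{1, 64, {8, 16, 32}}; // range with optimal sizes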
@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     auto dim_val = options.default_dim_value;
     if(dim_val != 0)
     {
-        if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
+        if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
         {
             MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
                            " set to non-default value");
         }
         else
         {
-            parser.default_dyn_dim_value = {dim_val, dim_val, 0};
+            parser.default_dyn_dim_value = {dim_val, dim_val};
         }
     }
     else
......
@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
             return default_dyn_dim_value;
         }
         std::size_t tmp = d.dim_value();
-        return {tmp, tmp, 0};
+        return {tmp, tmp};
     }
     else
     {
......
@@ -26,6 +26,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
+#include <migraphx/common.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -47,18 +48,15 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
         auto input_lens = args[0]->get_shape().lens();
         auto n_dim      = input_lens.size();

-        instruction_ref y_scale;
+        instruction_ref y_scale = args[1];
         if(args[1]->get_shape().elements() != 1)
         {
             auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
             y_scale         = info.add_instruction(
                 make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
         }
-        else
-        {
-            y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
-                                           args[1]);
-        }
+
+        auto common_args = add_common_args(*info.mod, {args[0], y_scale});

         if(args.size() == 3)
         {
@@ -76,10 +74,10 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
                     make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
             }
-            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
+            common_args.push_back(y_zero_point);
         }
-        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
+        return info.add_instruction(make_op("quantizelinear"), common_args);
     }
 };
......
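With the else branch gone, a scalar y_scale is left as args[1] and add_common_args (from the newly included migraphx/common.hpp) takes care of broadcasting. A hedged sketch of the helper's assumed behaviour:

    // Assumed behaviour of add_common_args (sketch, see common.hpp): inputs of
    // shapes {2, 3} and {1} both come back multibroadcast to the common shape
    // {2, 3}, so quantizelinear always sees shape-compatible inputs.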
@@ -53,8 +53,8 @@ struct parse_reshape : op_parser<parse_reshape>
             s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
         }
-        return info.add_instruction(make_op("reshape", {{"dims", dims}}),
-                                    info.make_contiguous(args[0]));
+        auto cont = info.add_instruction(make_op("contiguous"), args[0]);
+        return info.add_instruction(make_op("reshape", {{"dims", dims}}), cont);
     }
 };
......
@@ -86,12 +86,21 @@ struct module_pm : module_pass_manager
         assert(mod);
         return *mod;
     }

     virtual module* create_module(const std::string& name) override
     {
         assert(prog);
         return prog->create_module(name);
     }

     virtual module* get_common_parent() override { return common_parent; }

+    virtual module* get_root_module() override
+    {
+        assert(prog);
+        return prog->get_main_module();
+    }
+
     virtual void run_pass(const pass& p) override
     {
         trace("Pass: ", p.name());
......
@@ -331,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
                 MIGRAPHX_THROW("Parameter not found: " + param_name);
             auto param = params[param_name];
             // TODO: may want to check correct number of dimensions and/or was within bounds
-            if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
+            if(not ins->get_shape().any_of_dynamic() and
+               param.get_shape() != ins->get_shape())
             {
                 MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
                                "} for parameter: " + param_name +
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void promote_literals::apply(module_pass_manager& mpm) const
{
    module& m              = mpm.get_module();
    module_ref root_module = mpm.get_root_module();
    if(m.name() == "main")
        return;
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "@literal")
        {
            auto new_lit = root_module->add_literal(ins->get_literal());
            for(auto out_ins : ins->outputs())
            {
                out_ins->replace_argument(out_ins, ins, new_lit);
            }
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
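For illustration, a hypothetical before/after of this pass (names invented, not part of the commit): a @literal found in a non-main module is re-created in the main module and the submodule's uses are re-pointed at it, leaving the original literal dead for a later dead_code_elimination.

    // before:  main: ...                 sub: l = @literal{2.0f}
    //                                         y = mul(x, l)
    // after:   main: l2 = @literal{2.0f}
    //          ...                       sub: y = mul(x, l2)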
@@ -74,13 +74,23 @@ struct shape_impl
     shape_impl(shape::type_t t,
                std::vector<std::size_t> mins,
                std::vector<std::size_t> maxes,
-               std::vector<std::size_t> opts)
+               std::vector<std::set<std::size_t>> optimals_list)
         : m_type(t)
     {
-        assert(mins.size() == maxes.size() and maxes.size() == opts.size());
-        for(size_t i = 0; i < mins.size(); ++i)
-        {
-            m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
-        }
+        if(optimals_list.empty())
+        {
+            for(size_t i = 0; i < mins.size(); ++i)
+            {
+                m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
+            }
+        }
+        else
+        {
+            assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
+            for(size_t i = 0; i < mins.size(); ++i)
+            {
+                m_dyn_dims.push_back(
+                    shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
+            }
+        }
     }
@@ -147,7 +157,7 @@ struct shape_impl
         std::transform(m_dyn_dims.cbegin(),
                        m_dyn_dims.cend(),
                        ret.begin(),
-                       [](shape::dynamic_dimension x) { return x.min; });
+                       [](const shape::dynamic_dimension& x) { return x.min; });
         return ret;
     }
@@ -157,19 +167,20 @@ struct shape_impl
         std::transform(m_dyn_dims.cbegin(),
                        m_dyn_dims.cend(),
                        ret.begin(),
-                       [](shape::dynamic_dimension x) { return x.max; });
+                       [](const shape::dynamic_dimension& x) { return x.max; });
         return ret;
     }

-    std::vector<std::size_t> opt_lens() const
+    std::vector<std::set<std::size_t>> opt_lens() const
     {
-        std::vector<std::size_t> ret(m_dyn_dims.size());
+        std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
         std::transform(m_dyn_dims.cbegin(),
                        m_dyn_dims.cend(),
                        ret.begin(),
-                       [](shape::dynamic_dimension x) { return x.opt; });
+                       [](const shape::dynamic_dimension& x) { return x.optimals; });
         return ret;
     }

     // Does the shape skip over elements?
     bool skips() const
     {
...
@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
 shape::shape(type_t t,
              std::vector<std::size_t> mins,
              std::vector<std::size_t> maxes,
-             std::vector<std::size_t> opts)
-    : impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
+             std::vector<std::set<std::size_t>> optimals_list)
+    : impl(std::make_shared<shape_impl>(
+          t, std::move(mins), std::move(maxes), std::move(optimals_list)))
 {
 }
@@ -469,12 +481,44 @@ shape shape::with_type(type_t t) const

 shape shape::to_dynamic() const
 {
+    if(not sub_shapes().empty())
+    {
+        std::vector<shape> subs;
+        std::transform(sub_shapes().cbegin(),
+                       sub_shapes().cend(),
+                       std::back_inserter(subs),
+                       [](auto s) { return s.to_dynamic(); });
+        return {subs};
+    }
     if(this->dynamic())
     {
         return *this;
     }
-    std::vector<std::size_t> zeroes(this->ndim(), 0);
-    return {type(), lens(), lens(), zeroes};
+    return {type(), lens(), lens(), {}};
+}
+
+shape shape::to_static(std::size_t x) const
+{
+    if(not sub_shapes().empty())
+    {
+        std::vector<shape> subs;
+        std::transform(sub_shapes().cbegin(),
+                       sub_shapes().cend(),
+                       std::back_inserter(subs),
+                       [&](auto s) { return s.to_static(x); });
+        return {subs};
+    }
+    if(not this->dynamic())
+    {
+        return *this;
+    }
+    auto static_lens = this->max_lens();
+    std::transform(static_lens.begin(),
+                   static_lens.end(),
+                   this->dyn_dims().cbegin(),
+                   static_lens.begin(),
+                   [&](auto sl, auto dd) { return dd.is_fixed() ? sl : x; });
+    return {type(), static_lens};
 }

 std::size_t shape::element_space() const { return impl->element_space(); }
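A short usage sketch of the two conversions above (assumed values, mirroring the logic in the hunk):

    // to_dynamic turns each static len into a fixed {len, len} range with no
    // optimals; to_static pins every non-fixed range to the given size.
    shape s{shape::float_type, {1, 3, 224, 224}};
    auto d = s.to_dynamic();   // dyn_dims: {1,1}, {3,3}, {224,224}, {224,224}
    shape dyn{shape::float_type, {{1, 4}, {3, 3}}};
    auto t = dyn.to_static(2); // lens become {2, 3}; only the non-fixed dim changes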
@@ -506,23 +550,22 @@ std::vector<std::size_t> shape::max_lens() const
     return this->dynamic() ? impl->max_lens() : this->lens();
 }

-std::vector<std::size_t> shape::opt_lens() const
-{
-    return this->dynamic() ? impl->opt_lens() : this->lens();
-}
+std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }

 bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }

-bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
+bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }

 shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
 {
     this->min += x;
     this->max += x;
-    if(this->opt != 0)
-    {
-        this->opt += x;
-    };
+    std::set<std::size_t> new_optimals;
+    std::transform(this->optimals.begin(),
+                   this->optimals.end(),
+                   std::inserter(new_optimals, new_optimals.begin()),
+                   [&x](const auto& opt) { return (opt + x); });
+    this->optimals = new_optimals;
     return *this;
 }
@@ -532,19 +575,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
     assert(this->max >= x);
     this->min -= x;
     this->max -= x;
-    if(this->opt != 0)
-    {
-        assert(this->opt >= x);
-        this->opt -= x;
-    }
+    std::set<std::size_t> new_optimals;
+    std::transform(this->optimals.begin(),
+                   this->optimals.end(),
+                   std::inserter(new_optimals, new_optimals.begin()),
+                   [&x](const auto& opt) {
+                       assert(opt >= x);
+                       return (opt - x);
+                   });
+    this->optimals = new_optimals;
     return *this;
 }

 bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
-    // don't check opt if both are fixed
+    // don't check optimals if both are fixed
     return (x.min == y.min and x.max == y.max and
-            ((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
+            ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals)));
 }

 bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
@@ -553,7 +600,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 }

 std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
 {
-    os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
+    os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals)
+       << "} ]";
     return os;
 }
@@ -665,8 +712,10 @@ void migraphx_from_value(const value& v, shape& s)
         std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
             auto x_min = x.at("min").template to<size_t>();
             auto x_max = x.at("max").template to<size_t>();
-            auto x_opt = x.at("opt").template to<size_t>();
-            return shape::dynamic_dimension{x_min, x_max, x_opt};
+            auto set_x_optimals = from_value<std::set<std::size_t>>(x.at("optimals"));
+            return shape::dynamic_dimension{x_min, x_max, set_x_optimals};
         });
         s = shape{shape::parse_type(t), dyn_dims};
......
@@ -52,8 +52,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)

 auto conv_const_weights()
 {
-    return match::name("convolution")(match::used_once(),
-                                      match::args(match::any(), match::is_constant().bind("w")));
+    return match::name("convolution")(
+        match::used_once(),
+        match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
 }

 auto reduction() { return match::name_contains("reduce"); }
@@ -203,6 +204,7 @@ struct find_mul_slice_conv
     }
 };

 struct find_mul_dot
 {
     auto matcher() const
...
@@ -332,6 +334,14 @@ struct find_dot_mul
 };

-// a * (x + b) => a * x + a * b
+// ******************************
+// a * (x + b) => a * x + a * b
+// ******************************
+// When a * (x + b) is followed by another add of a constant, the additional
+// add can be const-folded. Also, better fusions can be applied when the add
+// comes after.
 struct find_mul_add
 {
     auto matcher() const
+struct find_conv_add
+{
+    auto matcher() const
+    {
+        auto add = match::name("add")(
+            match::either_arg(0, 1)(match::any().bind("x"),
+                                    match::any_of(match::is_constant()).bind("a")),
+            match::used_once());
+        return match::name("convolution")(match::used_once(),
+                                          match::args(add, match::is_constant().bind("w")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto x_ins = r.instructions["x"];
+        auto w_ins = r.instructions["w"];
+
+        auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
+        auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
+        m.replace_instruction(ins, make_op("add"), conv1, conv2);
+    }
+};
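The rewrite relies on convolution being linear in its input; a sketch of the identity being exploited:

    // convolution(x + a, w) == convolution(x, w) + convolution(a, w)
    // With a and w both constant, convolution(a, w) can be constant-folded by
    // later passes, leaving one convolution plus an elementwise add at runtime.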
 struct find_add_lit_broadcast
 {
     auto matcher() const
...
@@ -1369,6 +1405,7 @@ void simplify_algebra::apply(module& m) const
         find_neg_unit_ops{},
         find_zero_ops{},
         find_dot_add{},
+        find_conv_add{},
         find_div_const{},
         find_sub_const{},
         find_rsqrt{},
......
@@ -762,7 +762,7 @@ struct find_transpose_slice
             return;
         // Compute axis before transpose to use for unsqueeze
         auto perm    = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
-        auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
+        auto preaxis = perm[axis];
         // Make unsqueeze
         std::vector<int64_t> steps(sdistance.size());
         std::transform(
......
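Worked example of the fix above (assumed values): output axis i of a transpose is fed by input axis perm[i], so the pre-transpose axis is perm[axis], not the position of axis within perm.

    // perm = {0, 2, 3, 1}, slice axis = 1:
    //   new code: preaxis = perm[1]                  == 2 (correct pre-transpose axis)
    //   old code: std::find(perm, 1) - perm.begin()  == 3 (inverse permutation)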
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct dynamic_dimensions_check
{
    std::string dyn_param_str;
    size_t dyn_index;
    size_t min_dim;
    size_t max_dim;
};

// Returns a value only if the parameters contain exactly one dynamic shape
// with exactly one non-fixed dynamic_dimension.
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
    auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
    auto ps_it      = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
    if(ps_it == param_shapes.end())
        return std::nullopt;
    // Check if there is a second dynamic parameter
    if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
        return std::nullopt;
    const auto& dds   = ps_it->second.dyn_dims();
    auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
    auto dds_it       = std::find_if(dds.begin(), dds.end(), is_non_fixed);
    if(dds_it == dds.end())
        return std::nullopt;
    // Check if there is a second non-fixed dynamic_dimension
    if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
        return std::nullopt;
    return dynamic_dimensions_check{ps_it->first,
                                    static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
                                    dds_it->min,
                                    dds_it->max};
}
/**
 * Creates a static-shape submodule for each size in the dynamic_dimension
 * range. Probably won't work for `if` and `loop` instructions, depending on
 * how their submodules work. Inserts a select_module instruction at the top
 * and replaces the return, bypassing the other instructions.
 */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
    module_ref mm     = &mpm.get_module();
    auto param_names  = mm->get_parameter_names();
    auto param_shapes = mm->get_parameter_shapes();
    optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
    if(dd_check.has_value())
    {
        const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
        auto dyn_param_shape  = mm->get_parameter_shape(dd_check->dyn_param_str);
        std::vector<module_ref> submodules;
        // create submodules for each dimension size
        for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
        {
            auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
            // instruction map for new static shaped submodule parameters
            std::unordered_map<instruction_ref, instruction_ref> map_ins;
            // create static shape using dim_size
            auto static_lens = dyn_param_shape.max_lens();
            static_lens.at(dd_check->dyn_index) = dim_size;
            map_ins[dyn_param] = submod->add_parameter(
                dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
            auto outputs = submod->add_instructions(mm, map_ins);
            submod->add_return({outputs});
            submodules.push_back(submod);
        }
        // redirect to select_module operator and return
        std::vector<instruction_ref> sm_inputs;
        std::transform(param_names.cbegin(),
                       param_names.cend(),
                       std::back_inserter(sm_inputs),
                       [&](auto pn) { return mm->get_parameter(pn); });
        auto output_shapes       = mm->get_output_shapes();
        migraphx::shape out_attr = migraphx::shape{output_shapes};
        auto sm_ins              = mm->add_instruction(
            migraphx::make_op("select_module",
                              {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
            sm_inputs,
            submodules);
        std::vector<instruction_ref> outputs(output_shapes.size());
        for(size_t i = 0; i < output_shapes.size(); ++i)
        {
            outputs.at(i) =
                mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins);
        }
        mm->replace_return(outputs);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
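A sketch of the pass's effect on a module whose only dynamic parameter is a batch dimension (hypothetical shapes):

    // parameter x: shape{float_type, {{1, 4}, {3, 3}, {224, 224}, {224, 224}}}
    // after the pass, main holds submodules dim_1 .. dim_4 (static copies with
    // the batch pinned to 1..4) and ends with:
    //   out  = select_module[output_dyn_shapes=...](x, dim_1, dim_2, dim_3, dim_4)
    //   ret0 = get_tuple_elem[index=0](out)
    //   @return(ret0)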
@@ -25,7 +25,7 @@
 #define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP

 // #define MIGRAPHX_DISABLE_OMP

+#include <cmath>
 #include <migraphx/config.hpp>

 #ifdef MIGRAPHX_DISABLE_OMP
 #include <migraphx/par_for.hpp>
......
@@ -82,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options
         dead_code_elimination{},
         simplify_algebra{},
         simplify_reshapes{},
-        layout_nhwc{},
         dead_code_elimination{},
         simplify_reshapes{},
         simplify_algebra{},
......
@@ -168,7 +168,7 @@ std::string make_transformer_args(std::vector<std::string> transformers)
     return join_strings(std::move(transformers), ", ");
 }

-std::string generate_pointwise(const module& pm, const std::string& name)
+void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
 {
     module m = pm;
     run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
@@ -184,8 +184,131 @@
     // Add explicit conversions
     g.fresult(
         [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
-    g.create_function(
-        g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
+    gg.create_function(g.generate_module(m)
+                           .set_attributes({"__device__", "__attribute__((const))"})
+                           .set_generic_types(m)
+                           .set_name(name));
+}
+
+std::string generate_pointwise(const module& pm, const std::string& name)
+{
+    cpp_generator g;
+    generate_pointwise(g, pm, name);
+    return g.str();
+}
+
+std::string reduce_op::str() const
+{
+    return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
+}
+
+void reduce_op::set(instruction_ref ins, const operation& op)
+{
+    if(op.name() == "reduce_sum")
+    {
+        reduction = "op::sum{}";
+    }
+    else if(op.name() == "reduce_mean")
+    {
+        auto s               = ins->inputs().front()->get_shape();
+        auto reduce_elements = s.elements() / ins->get_shape().elements();
+        auto reduce_type     = s.type();
+        reduction            = "op::sum{}";
+        std::string mean     = "op::mean<" + std::to_string(reduce_elements) + ">{}";
+        // Use float accumulator when reduction size is too large for half
+        if(reduce_type == shape::half_type and reduce_elements > 16384)
+            read = "compose(" + mean + ", op::convert_to<float>{})";
+        else if(contains({shape::float_type, shape::half_type, shape::double_type}, reduce_type))
+            read = mean;
+        else
+            write = mean;
+    }
+    else if(op.name() == "reduce_max")
+    {
+        reduction = "op::max{}";
+        init      = "lowest{}";
+    }
+    else if(op.name() == "reduce_min")
+    {
+        reduction = "op::min{}";
+        init      = "highest{}";
+    }
+    else if(op.name() == "reduce_prod")
+    {
+        reduction = "op::product{}";
+        init      = "1";
+    }
+    else
+    {
+        MIGRAPHX_THROW("Unsupported reduce");
+    }
+}
+
+std::string reduce_op::generate(instruction_ref ins, const std::string& x)
+{
+    reduce_op r{x};
+    r.set(ins, ins->get_operator());
+    return r.str();
+}
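For reference, the string this emits for a plain reduce_sum whose input tensor is named x0, derived from str() and the member defaults declared in the header below:

    // op::id{}(r.reduce(op::sum{}, 0, op::id{})(x0))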
+
+static bool use_lazy_inner(instruction_ref ins)
+{
+    if(ins->outputs().size() != 1)
+        return false;
+    auto output = ins->outputs().front();
+    return contains(output->name(), "reduce") or output->name() == "@return";
+}
+
+std::string generate_reduce(const module& m, const std::string& name)
+{
+    cpp_generator g;
+    auto ilens    = m.get_parameter_shapes().begin()->second.lens();
+    std::size_t i = 0;
+    auto f        = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
+        if(contains(ins->name(), "reduce"))
+        {
+            return reduce_op::generate(ins, names.at(ins->inputs().front()));
+        }
+        else if(ins->name() == "pointwise")
+        {
+            auto pointwise_name = "pointwise" + std::to_string(i);
+            i++;
+            generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
+            std::vector<instruction_ref> tensors;
+            std::copy_if(ins->inputs().begin(),
+                         ins->inputs().end(),
+                         std::back_inserter(tensors),
+                         [&](auto input) {
+                             return input->get_shape().lens() == ilens and
+                                    not input->get_shape().broadcasted();
+                         });
+            auto inner_names = names;
+            for(auto input : tensors)
+                inner_names[input] += "_lambda_param";
+            auto call_function =
+                pointwise_name + "(" +
+                join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
+            if(tensors.empty())
+                return call_function;
+            const std::string inner_template =
+                "r.${inner}([=](${params}) { return ${call}; })(${args})";
+            std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
+            auto args              = cpp_generator::to_args(tensors, names);
+            auto params            = cpp_generator::to_args(tensors, inner_names);
+            std::transform(
+                params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
+            return interpolate_string(inner_template,
+                                      {{"inner", inner_name},
+                                       {"params", join_strings(params, ", ")},
+                                       {"args", join_strings(args, ", ")},
+                                       {"call", call_function}});
+        }
+        else if(ins->name() == "multibroadcast")
+        {
+            return names.at(ins->inputs().front());
+        }
+        MIGRAPHX_THROW("Unknown operator: " + ins->name());
+    });
+    f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
+    f.add_generic_param("r");
+    g.create_function(f);
     return g.str();
 }
@@ -196,8 +319,18 @@ static std::vector<std::string> get_op_names(const module& m)
     {
         if(starts_with(ins.name(), "@"))
             continue;
+        if(ins.name() == "multibroadcast")
+            continue;
+        if(ins.name() == "pointwise")
+        {
+            auto names = get_op_names(*ins.module_inputs().front());
+            result.insert(result.end(), names.begin(), names.end());
+        }
+        else
+        {
             result.push_back(ins.name());
+        }
     }
     return result;
 }
......
@@ -189,8 +189,20 @@ argument register_on_gpu(const argument& arg)

 argument to_gpu(const argument& arg, bool host)
 {
-    auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host);
-    return {arg.get_shape(), p};
+    argument result;
+    arg.visit(
+        [&](auto x) {
+            auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host);
+            result = {x.get_shape(), p};
+        },
+        [&](const auto& xs) {
+            std::vector<argument> args;
+            std::transform(xs.begin(), xs.end(), std::back_inserter(args), [&](auto x) {
+                return to_gpu(x, host);
+            });
+            result = argument{args};
+        });
+    return result;
 }

 argument from_gpu(const argument& arg)
......
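The second visitor branch handles tuple arguments, which the old single-buffer version could not. A hedged usage sketch (arg_a and arg_b are hypothetical arguments):

    // A tuple argument is copied element by element and rebuilt on the GPU side.
    argument tup{{arg_a, arg_b}};
    auto gpu_tup = to_gpu(tup, false); // recurses into each sub-argument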
@@ -26,6 +26,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/module_ref.hpp>
+#include <migraphx/instruction_ref.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -34,6 +35,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 struct shape;
+struct operation;

 namespace gpu {
@@ -72,8 +74,23 @@ std::string make_transformer_args(Ts... xs)

 std::string generate_pointwise(const module& pm, const std::string& name);

+std::string generate_reduce(const module& m, const std::string& name);
+
 std::string generate_name_from_ops(const module& m);

+struct reduce_op
+{
+    std::string input     = "";
+    std::string reduction = "";
+    std::string init      = "0";
+    std::string read      = "op::id{}";
+    std::string write     = "op::id{}";
+
+    void set(instruction_ref ins, const operation& op);
+    std::string str() const;
+    static std::string generate(instruction_ref ins, const std::string& x);
+};
+
 } // namespace gen
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
......
@@ -71,6 +71,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_options
 std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);

+std::string generate_make_shape(const shape& s);
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
-#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
+#define MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP

 #include <migraphx/shape.hpp>
 #include <migraphx/generate.hpp>
......
@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
 )__migraphx__";

-static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
-{
-    return inputs.front().elements() / inputs.back().elements();
-}
-
-static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
-{
-    return get_reduce_elements(to_shapes(inputs));
-}
-
 static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                 const std::vector<std::size_t>& output_lens)
 {
@@ -86,9 +77,28 @@
     return reduce_lens;
 }

-static std::string get_reduce_algo(const std::vector<shape>& inputs)
+template <class T>
+static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
+{
+    auto lens = s.lens();
+    std::fill(lens.begin(), lens.end(), 1);
+    for(const auto& axis : axes)
+        lens[axis] = s.lens()[axis];
+    return shape{s.type(), lens};
+}
+
+template <class T>
+static shape get_output_shape(const shape& s, const std::vector<T>& axes)
+{
+    auto lens = s.lens();
+    for(const auto& axis : axes)
+        lens[axis] = 1;
+    return shape{s.type(), lens};
+}
+
+template <class ReduceLens>
+static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
 {
-    auto rlens      = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
     const auto init = std::numeric_limits<std::size_t>::max();
     // The minimum stride
     auto min_stride = std::inner_product(
@@ -103,11 +113,27 @@
     return "block";
 }

+static std::string get_reduce_algo(const std::vector<shape>& inputs)
+{
+    auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
+    return get_reduce_algo(inputs, rlens);
+}
+
-struct reduce_compiler : compiler<reduce_compiler>
+struct simple_reduce_compiler : compiler<simple_reduce_compiler>
 {
     std::vector<std::string> names() const
     {
-        return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
+        return {"simple_reduce",
+                "reduce_sum",
+                "reduce_mean",
+                "reduce_max",
+                "reduce_min",
+                "reduce_prod"};
+    }
+
+    static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
+    {
+        return inputs.front().elements() / inputs.back().elements();
     }

     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
@@ -157,44 +183,108 @@
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
         value v = value::object{};
-        if(op.name() == "reduce_sum")
-        {
-            v["reduction"] = "op::sum{}";
-        }
-        else if(op.name() == "reduce_mean")
-        {
-            auto reduce_elements = get_reduce_elements(ins->inputs());
-            auto reduce_type     = ins->inputs().front()->get_shape().type();
-            v["reduction"]       = "op::sum{}";
-            std::string mean     = "op::mean<" + std::to_string(reduce_elements) + ">{}";
-            // Use float accumulator when reduction size is too large for half
-            if(reduce_type == shape::half_type and reduce_elements > 16384)
-                v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
-            else if(contains({shape::float_type, shape::half_type, shape::double_type},
-                             reduce_type))
-                v["read"] = mean;
-            else
-                v["write"] = mean;
-        }
-        else if(op.name() == "reduce_max")
-        {
-            v["reduction"] = "op::max{}";
-            v["init"]      = "lowest{}";
-        }
-        else if(op.name() == "reduce_min")
-        {
-            v["reduction"] = "op::min{}";
-            v["init"]      = "highest{}";
-        }
-        else if(op.name() == "reduce_prod")
-        {
-            v["reduction"] = "op::product{}";
-            v["init"]      = "1";
-        }
-        else
-        {
-            MIGRAPHX_THROW("Unsupported reduce");
-        }
+        reduce_op r{};
+        r.set(ins, op);
+        v["reduction"] = r.reduction;
+        v["read"]      = r.read;
+        v["write"]     = r.write;
+        v["init"]      = r.init;
         return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
 };
+
+static const char* const fused_reduce_kernel = R"__migraphx__(
+#include <migraphx/kernels/index.hpp>
+#include <migraphx/kernels/reduce.hpp>
+#include <migraphx/kernels/pointwise.hpp>
+#include <migraphx/kernels/vectorize.hpp>
+#include <args.hpp>
+
+namespace migraphx {
+
+${preamble}
+
+extern "C" {
+MIGRAPHX_GLOBAL void ${kernel}(${params})
+{
+    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
+        fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
+    });
+}
+}
+
+} // namespace migraphx
+
+)__migraphx__";
+
+struct fused_reduce_compiler : compiler<fused_reduce_compiler>
+{
+    std::vector<std::string> names() const { return {"fused_reduce"}; }
+
+    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    {
+        auto axes           = v.at("axes").to_vector<std::size_t>();
+        auto virtual_inputs = inputs;
+        virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
+        virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
+        virtual_inputs           = reduce_dims(virtual_inputs);
+        auto reduce_output_shape = virtual_inputs.back();
+        virtual_inputs.pop_back();
+        auto reduction_shape = virtual_inputs.back();
+        virtual_inputs.pop_back();
+
+        hip_compile_options options;
+        options.inputs         = inputs;
+        options.output         = inputs.back();
+        options.virtual_inputs = virtual_inputs;
+        auto faxis             = find_fast_axis({options.virtual_inputs.front()});
+        vectorize vec{};
+        auto nelements = reduce_output_shape.elements();
+        auto algo =
+            v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
+        if(algo == "block")
+        {
+            // Vectorize if the axis is a reduction axis
+            if(reduce_output_shape.lens()[faxis] == 1)
+                vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
+            auto relements  = reduction_shape.elements() / vec.size;
+            auto block_size = compute_block_size(relements, 256);
+            if(relements >= block_size * 256)
+                algo = "block_large";
+            options.set_launch_params(
+                v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+        }
+        else if(algo == "lane")
+        {
+            options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
+        }
+        else
+        {
+            MIGRAPHX_THROW("Unknown reduce algo: " + algo);
+        }
+        options.kernel_name = v.get("kernel", "reduce_kernel");
+        auto src            = interpolate_string(
+            fused_reduce_kernel,
+            {{"kernel", options.kernel_name},
+             {"params", enum_params(inputs.size(), "void * private_p")},
+             {"args", enum_params(inputs.size(), "private_p")},
+             {"algo", algo},
+             {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
+             {"lambda", v.at("lambda").to<std::string>()},
+             {"transformers", make_transformer_args(vec)},
+             {"preamble", v.get("preamble", std::string{})}});
+        options.params += " -Wno-float-equal";
+        return compile_hip_code_object(src, options);
+    }
+
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
+    {
+        assert(not ins->module_inputs().empty());
+        auto v        = op.to_value();
+        auto* rm      = ins->module_inputs().front();
+        v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
+        v["lambda"]   = "MIGRAPHX_LIFT(fused_reduce_op)";
+        v["kernel"]   = generate_name_from_ops(*rm) + "_kernel";
+        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
+    }
+};
......
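A hedged sketch of what the fused_reduce compile step produces, assuming generate_name_from_ops joins the fused module's op names:

    // For a module fusing add + reduce_sum (assumed naming):
    //   v["kernel"]   == "add_reduce_sum_kernel"
    //   v["preamble"]  : the __device__ fused_reduce_op emitted by generate_reduce
    //   v["lambda"]   == "MIGRAPHX_LIFT(fused_reduce_op)"
    // interpolate_string splices these into fused_reduce_kernel before HIP compilation.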