Commit cd4ab535 authored by Khalique Ahmed

manual merge

parents 3891ee58 a0fa3742
@@ -74,13 +74,23 @@ struct shape_impl
shape_impl(shape::type_t t,
           std::vector<std::size_t> mins,
           std::vector<std::size_t> maxes,
           std::vector<std::set<std::size_t>> optimals_list)
    : m_type(t)
{
    if(optimals_list.empty())
    {
        for(size_t i = 0; i < mins.size(); ++i)
        {
            m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
        }
    }
    else
    {
        assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
        for(size_t i = 0; i < mins.size(); ++i)
        {
            m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
        }
    }
}
@@ -147,7 +157,7 @@ struct shape_impl
    std::transform(m_dyn_dims.cbegin(),
                   m_dyn_dims.cend(),
                   ret.begin(),
                   [](const shape::dynamic_dimension& x) { return x.min; });
    return ret;
}
@@ -157,19 +167,20 @@ struct shape_impl
    std::transform(m_dyn_dims.cbegin(),
                   m_dyn_dims.cend(),
                   ret.begin(),
                   [](const shape::dynamic_dimension& x) { return x.max; });
    return ret;
}

std::vector<std::set<std::size_t>> opt_lens() const
{
    std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
    std::transform(m_dyn_dims.cbegin(),
                   m_dyn_dims.cend(),
                   ret.begin(),
                   [](const shape::dynamic_dimension& x) { return x.optimals; });
    return ret;
}

// Does the shape skip over elements?
bool skips() const
{
@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
shape::shape(type_t t,
             std::vector<std::size_t> mins,
             std::vector<std::size_t> maxes,
             std::vector<std::set<std::size_t>> optimals_list)
    : impl(std::make_shared<shape_impl>(
          t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{
}
@@ -349,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
    }
}

std::vector<std::size_t> shape::multi(std::size_t idx) const
{
    assert(idx < elements());
    std::vector<std::size_t> indices(lens().size());
    multi_copy(idx, indices.data(), indices.data() + lens().size());
    return indices;
}

void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
    size_t tidx = idx;
    (void)end;
    assert(idx < elements());
    assert(lens().size() <= (end - start));
    for(size_t ii = lens().size() - 1; ii > 0; ii--)
    {
        *(start + ii) = tidx % lens()[ii];
        tidx          = tidx / lens()[ii];
    }
    *start = tidx;
}
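Because the rewrite depends only on lens() and not strides(), multi_copy no longer asserts this->standard(). A standalone sketch (plain C++) of the same row-major peel-off:

#include <cassert>
#include <cstddef>
#include <vector>

std::vector<std::size_t> multi_index(std::size_t idx, const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> out(lens.size());
    for(std::size_t i = lens.size() - 1; i > 0; i--)
    {
        out[i] = idx % lens[i]; // coordinate along dimension i
        idx    = idx / lens[i]; // remaining flat index for the outer dimensions
    }
    out[0] = idx;
    return out;
}

int main()
{
    // For lens {2, 3, 4}, flat index 17 = 1*12 + 1*4 + 1 -> {1, 1, 1}
    assert((multi_index(17, {2, 3, 4}) == std::vector<std::size_t>{1, 1, 1}));
}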
bool shape::packed() const
@@ -469,12 +478,44 @@ shape shape::with_type(type_t t) const
shape shape::to_dynamic() const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> subs;
        std::transform(sub_shapes().cbegin(),
                       sub_shapes().cend(),
                       std::back_inserter(subs),
                       [](auto s) { return s.to_dynamic(); });
        return {subs};
    }
    if(this->dynamic())
    {
        return *this;
    }
    return {type(), lens(), lens(), {}};
}

shape shape::to_static(std::size_t x) const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> subs;
        std::transform(sub_shapes().cbegin(),
                       sub_shapes().cend(),
                       std::back_inserter(subs),
                       [&](auto s) { return s.to_static(x); });
        return {subs};
    }
    if(not this->dynamic())
    {
        return *this;
    }
    auto static_lens = this->max_lens();
    std::transform(static_lens.begin(),
                   static_lens.end(),
                   this->dyn_dims().cbegin(),
                   static_lens.begin(),
                   [&](auto sl, auto dd) { return dd.is_fixed() ? sl : x; });
    return {type(), static_lens};
}
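The new to_static goes in the opposite direction of to_dynamic; a hedged sketch of its effect on an illustrative shape (values not from this commit):

// Non-fixed dynamic dimensions collapse to the given value, fixed ones keep
// their length; illustrative values.
migraphx::shape dyn{migraphx::shape::float_type,
                    {migraphx::shape::dynamic_dimension{1, 8},      // non-fixed batch
                     migraphx::shape::dynamic_dimension{3, 3},      // fixed
                     migraphx::shape::dynamic_dimension{224, 224}}}; // fixed
auto fixed = dyn.to_static(4);
// fixed.lens() == {4, 3, 224}; fixed.to_dynamic() would rebuild it with
// fixed dynamic_dimensions and empty optimals.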
std::size_t shape::element_space() const { return impl->element_space(); }
@@ -506,23 +547,22 @@ std::vector<std::size_t> shape::max_lens() const
    return this->dynamic() ? impl->max_lens() : this->lens();
}

std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }

bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }

shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
    this->min += x;
    this->max += x;
    std::set<std::size_t> new_optimals;
    std::transform(this->optimals.begin(),
                   this->optimals.end(),
                   std::inserter(new_optimals, new_optimals.begin()),
                   [&x](const auto& opt) { return (opt + x); });
    this->optimals = new_optimals;
    return *this;
}

@@ -532,19 +572,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
    assert(this->max >= x);
    this->min -= x;
    this->max -= x;
    std::set<std::size_t> new_optimals;
    std::transform(this->optimals.begin(),
                   this->optimals.end(),
                   std::inserter(new_optimals, new_optimals.begin()),
                   [&x](const auto& opt) {
                       assert(opt >= x);
                       return (opt - x);
                   });
    this->optimals = new_optimals;
    return *this;
}
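A small illustration of the new optimals bookkeeping (values illustrative): the shift operators now translate every element of the set rather than a single opt scalar.

migraphx::shape::dynamic_dimension dd{2, 6, {2, 4}};
dd += 1; // now {3, 7, {3, 5}}
dd -= 2; // now {1, 5, {1, 3}}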
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
    // don't check optimals if both are fixed
    return (x.min == y.min and x.max == y.max and
            ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals)));
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
@@ -553,7 +597,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
}

std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
    os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]";
    return os;
}
@@ -662,12 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
    auto v_dd = v.at("dynamic_dimensions");
    std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
    std::transform(
        v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
            return from_value<shape::dynamic_dimension>(x);
        });
    s = shape{shape::parse_type(t), dyn_dims};
}
...
@@ -39,6 +39,8 @@
#include <migraphx/algorithm.hpp>
#include <unordered_set>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -52,8 +54,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto conv_const_weights()
{
    return match::name("convolution")(
        match::used_once(),
        match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}

auto reduction() { return match::name_contains("reduce"); }
@@ -203,7 +206,137 @@ struct find_mul_slice_conv
    }
};

struct find_mul_dot
{
    auto matcher() const
    {
        auto is_dot_const_inputs =
            match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
        return match::name("mul")(match::either_arg(0, 1)(
            is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins     = r.result;
        auto dot_ins = r.instructions["dot"];
        auto a_ins   = dot_ins->inputs()[0];
        auto b_ins   = dot_ins->inputs()[1];
        auto c_ins   = r.instructions["c"];
        const auto& c_strides = c_ins->get_shape().strides();
        // There should only be one stride that is not zero
        if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
            return;
        auto add_mul_const = [&](instruction_ref x_ins) {
            if(not x_ins->can_eval())
                return m.end();
            auto broadcast_v        = c_ins->get_operator().to_value();
            broadcast_v["out_lens"] = x_ins->get_shape().lens();
            auto cb_ins =
                m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
            return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
        };
        if(c_strides.back() == 1)
        {
            b_ins = add_mul_const(b_ins);
        }
        else if(c_strides[c_strides.size() - 2] == 1)
        {
            a_ins = add_mul_const(a_ins);
        }
        else if(c_ins->get_shape().scalar())
        {
            if(a_ins->can_eval())
                a_ins = add_mul_const(a_ins);
            else
                b_ins = add_mul_const(b_ins);
        }
        else
        {
            return;
        }
        if(contains({a_ins, b_ins}, m.end()))
            return;
        m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
    }
};
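The rewrite relies on diagonal scaling commuting with matrix multiply: when c varies only along columns, (A·B)∘c equals A·(B∘c), so the mul can be folded into the constant dot operand. A minimal plain-C++ check of that identity (not MIGraphX code):

#include <array>
#include <cassert>

int main()
{
    std::array<std::array<int, 2>, 2> a{{{1, 2}, {3, 4}}};
    std::array<std::array<int, 2>, 2> b{{{5, 6}, {7, 8}}};
    std::array<int, 2> c{2, 3}; // constant broadcast along the last (column) axis

    for(int i = 0; i < 2; i++)
        for(int j = 0; j < 2; j++)
        {
            int dot = 0, dot_scaled = 0;
            for(int k = 0; k < 2; k++)
            {
                dot += a[i][k] * b[k][j];
                dot_scaled += a[i][k] * (b[k][j] * c[j]); // scale B's columns first
            }
            assert(dot * c[j] == dot_scaled); // mul after dot == mul folded into B
        }
}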
struct find_dot_mul
{
    auto matcher() const
    {
        auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
        auto mul = match::name("mul")(
            match::used_once(),
            match::either_arg(0, 1)(const_broadcast.bind("d"),
                                    match::none_of(match::is_constant()).bind("z")));
        return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto a_ins = ins->inputs()[0];
        auto b_ins = ins->inputs()[1];
        auto d_ins = r.instructions["d"];
        auto c_ins = r.instructions["c"];
        auto z_ins = r.instructions["z"];
        const auto& d_strides = d_ins->get_shape().strides();
        // There should only be one stride that is not zero
        if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
            return;
        if(not d_ins->get_shape().scalar())
        {
            if(d_strides.back() == 1 and not b_ins->can_eval())
                return;
            if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
                return;
        }
        auto broadcast_v = d_ins->get_operator().to_value();
        auto c_lens      = c_ins->get_shape().lens();
        std::vector<int64_t> permutation(c_lens.size());
        std::iota(permutation.begin(), permutation.end(), 0);
        std::swap(permutation.back(), permutation[permutation.size() - 2]);
        c_lens                  = reorder_dims(c_lens, permutation);
        broadcast_v["out_lens"] = c_lens;
        auto db_ins =
            m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
        auto db_transpose_ins =
            m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
        auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
        if(c_ins == b_ins)
        {
            a_ins = z_ins;
            b_ins = cd_ins;
        }
        else
        {
            a_ins = cd_ins;
            b_ins = z_ins;
        }
        m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
    }
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of a constant, the additional
// add can be constant-folded. Better fusions can also be applied when the
// add comes after.
struct find_mul_add
{
    auto matcher() const
@@ -268,6 +401,32 @@ struct find_dot_add
    }
};
struct find_conv_add
{
    auto matcher() const
    {
        auto add = match::name("add")(
            match::either_arg(0, 1)(match::any().bind("x"),
                                    match::any_of(match::is_constant()).bind("a")),
            match::used_once());
        return match::name("convolution")(match::used_once(),
                                          match::args(add, match::is_constant().bind("w")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto a_ins = r.instructions["a"];
        auto x_ins = r.instructions["x"];
        auto w_ins = r.instructions["w"];
        auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
        auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
        m.replace_instruction(ins, make_op("add"), conv1, conv2);
    }
};
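This exploits the linearity of convolution: conv(x + a, w) == conv(x, w) + conv(a, w), so the constant half can later be folded. A minimal 1-D sketch (plain C++, illustrative):

#include <cassert>
#include <vector>

std::vector<int> conv1d(const std::vector<int>& in, const std::vector<int>& w)
{
    std::vector<int> out(in.size() - w.size() + 1, 0);
    for(size_t i = 0; i < out.size(); i++)
        for(size_t k = 0; k < w.size(); k++)
            out[i] += in[i + k] * w[k];
    return out;
}

int main()
{
    std::vector<int> x{1, 2, 3, 4}, a{5, 6, 7, 8}, w{1, -1};
    std::vector<int> sum(x.size());
    for(size_t i = 0; i < x.size(); i++)
        sum[i] = x[i] + a[i];
    auto lhs = conv1d(sum, w);                 // conv(x + a, w)
    auto cx = conv1d(x, w), ca = conv1d(a, w); // conv(x, w), conv(a, w)
    for(size_t i = 0; i < lhs.size(); i++)
        assert(lhs[i] == cx[i] + ca[i]);       // linearity holds elementwise
}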
struct find_add_lit_broadcast
{
    auto matcher() const
@@ -329,30 +488,123 @@ struct find_inner_broadcast
{
    auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }

    static auto non_scalar_op(const std::string& name)
    {
        return [=](instruction_ref ins) {
            if(ins->get_shape().scalar())
                return false;
            return ins->name() == name;
        };
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins        = r.result;
        auto broadcasts = ins->inputs();
        if(broadcasts.empty())
            return;
        // Skip if different data types are used
        if(any_of(broadcasts, [&](auto i) {
               return i->get_shape().type() != broadcasts.front()->get_shape().type();
           }))
            return;

        bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
                                any_of(broadcasts, non_scalar_op("multibroadcast"));
        // If the broadcast is not a single dimension, then don't perform inner_broadcast
        if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
               if(i->get_shape().scalar())
                   return false;
               if(i->name() == "multibroadcast")
                   return false;
               auto input       = i->inputs().at(0);
               const auto& lens = input->get_shape().lens();
               return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
                          return d == 1;
                      }) < (lens.size() - 1);
           }))
            return;

        std::vector<instruction_ref> inputs;
        std::transform(broadcasts.begin(),
                       broadcasts.end(),
                       std::back_inserter(inputs),
                       [&](instruction_ref i) {
                           auto input = i->inputs().front();
                           if(mixed_broadcasts and not i->get_shape().scalar() and
                              i->get_shape().lens().size() > 1)
                               return m.insert_instruction(i, make_op("squeeze"), input);
                           return input;
                       });

        std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
                      if(i->get_shape().scalar())
                          return 2;
                      else if(i->name() == "broadcast")
                          return 0;
                      if(i->name() == "multibroadcast")
                          return 1;
                      return 3;
                  }));

        auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
    }
};
struct find_dot_broadcast
{
    auto matcher() const
    {
        return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto a   = ins->inputs()[0];
        auto b   = ins->inputs()[1];
        if(a->get_operator().name() != b->get_operator().name())
            return;
        if(ins->get_shape().lens().size() < 3)
            return;
        auto nbatch_axes      = ins->get_shape().lens().size() - 2;
        const auto& a_strides = a->get_shape().strides();
        const auto& b_strides = b->get_shape().strides();
        // Find leading batch axes that are broadcasted
        auto p =
            std::mismatch(a_strides.begin(),
                          a_strides.begin() + nbatch_axes,
                          b_strides.begin(),
                          b_strides.begin() + nbatch_axes,
                          [](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
        auto naxes = p.first - a_strides.begin();
        assert(naxes <= nbatch_axes);
        std::vector<std::size_t> axes(naxes);
        std::iota(axes.begin(), axes.end(), 0);
        auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
            auto input = b_ins->inputs()[0];
            std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
                                          b_ins->get_shape().lens().end());
            if(b_ins->name() == "multibroadcast")
            {
                return m.insert_instruction(
                    ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
            }
            else if(b_ins->name() == "broadcast")
            {
                auto v    = b_ins->get_operator().to_value();
                auto axis = v.at("axis").to<std::size_t>() - naxes;
                return m.insert_instruction(
                    ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
            }
            assert(false);
            return m.end();
        };
        auto a1        = insert_broadcast(a);
        auto b1        = insert_broadcast(b);
        auto dot       = m.insert_instruction(ins, make_op("dot"), a1, b1);
        auto broadcast = m.insert_instruction(
            ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
        m.replace_instruction(ins, broadcast);
    }
};
@@ -361,7 +613,8 @@ struct find_concat_op
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](
            match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
            match::used_once()));
    }

    template <class Iterator>
@@ -380,7 +633,8 @@ struct find_concat_op
    static bool is_valid_op(const operation& op)
    {
        return contains({"broadcast", "multibroadcast"}, op.name()) or
               op.attributes().contains("pointwise");
    }

    void apply(module& m, const match::matcher_result& r) const
@@ -408,6 +662,16 @@ struct find_concat_op
            op    = b;
            iaxis = 0;
        }
        else if(op.name() == "multibroadcast")
        {
            shape bshape = (*start)->get_shape();
            auto input   = (*start)->inputs()[0];
            if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
                return {start, last};
            op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
            auto delta = bshape.lens().size() - input->get_shape().lens().size();
            iaxis -= delta;
        }

        std::vector<instruction_ref> concats;
        for(std::size_t i = 0; i < x->inputs().size(); i++)
@@ -1223,22 +1487,29 @@ struct find_split_transpose
void simplify_algebra::apply(module& m) const
{
    size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
    // Run simplifications multiple times
    for(int i = 0; i < 8; i++)
    {
        match::find_matches(trace,
                            m,
                            find_inner_broadcast{},
                            find_dot_broadcast{},
                            find_double_add_lit_broadcast{},
                            find_add_lit_broadcast{},
                            find_add_convs{},
                            find_conv_dot_horiz_fusion{},
                            find_mul_conv{},
                            find_mul_slice_conv{},
                            find_mul_dot{},
                            find_dot_mul{},
                            find_mul_add{},
                            find_unit_ops{},
                            find_neg_unit_ops{},
                            find_zero_ops{},
                            find_dot_add{},
                            find_conv_add{},
                            find_div_const{},
                            find_sub_const{},
                            find_rsqrt{},
...
@@ -762,7 +762,7 @@ struct find_transpose_slice
        return;
    // Compute axis before transpose to use for unsqueeze
    auto perm    = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
    auto preaxis = perm[axis];
    // Make unsqueeze
    std::vector<int64_t> steps(sdistance.size());
    std::transform(
...
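The fix reads the pre-transpose axis directly: after a transpose with permutation perm, output axis k holds what was input axis perm[k], so a lookup replaces the linear search. A small standalone check (plain C++):

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<int64_t> perm{2, 0, 1};       // output axis k <- input axis perm[k]
    std::vector<int64_t> in_dims{10, 20, 30}; // input shape
    std::vector<int64_t> out_dims(3);
    for(size_t k = 0; k < perm.size(); k++)
        out_dims[k] = in_dims[perm[k]];       // transposed shape: {30, 10, 20}
    int64_t axis    = 1;                      // axis in the transposed output
    int64_t preaxis = perm[axis];             // corresponding input axis: 0
    assert(out_dims[axis] == in_dims[preaxis]);
}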
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check
{
std::string dyn_param_str;
size_t dyn_index;
size_t min_dim;
size_t max_dim;
};
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
if(ps_it == param_shapes.end())
return std::nullopt;
// Check if there is a second dynamic parameter
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
return std::nullopt;
const auto& dds = ps_it->second.dyn_dims();
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed);
if(dds_it == dds.end())
return std::nullopt;
// Check if there is a second non-fixed dynamic_dimension
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
return std::nullopt;
return dynamic_dimensions_check{ps_it->first,
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
dds_it->min,
dds_it->max};
}
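A hedged usage sketch of has_one_dyn_dim (shapes illustrative, not from this commit): one dynamic parameter "x" whose first dynamic_dimension {1, 4} is the only non-fixed one qualifies, and the result records that axis and its range.

std::unordered_map<std::string, migraphx::shape> param_shapes = {
    {"x",
     migraphx::shape{migraphx::shape::float_type,
                     {migraphx::shape::dynamic_dimension{1, 4},     // non-fixed
                      migraphx::shape::dynamic_dimension{3, 3}}}}}; // fixed
auto ddc = has_one_dyn_dim(param_shapes);
assert(ddc.has_value());
// ddc->dyn_param_str == "x", ddc->dyn_index == 0, ddc->min_dim == 1, ddc->max_dim == 4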
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
 * Creates a static-shaped submodule for each dimension size in the dynamic_dimension
 * range. Probably won't work for `if` and `loop` instructions, depending on how the
 * submodules for those work. Inserts a select_module instruction at the top and
 * replaces the return, bypassing the other instructions. Skips if the dynamic
 * parameter already outputs to a select_module operator.
 */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
{
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dd_check->dyn_index) = dim_size;
map_ins[dyn_param] = submod->add_parameter(
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
auto output_shapes = mm->get_output_shapes();
migraphx::shape out_attr = migraphx::shape{output_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
submodules);
std::vector<instruction_ref> outputs(output_shapes.size());
for(size_t i = 0; i < output_shapes.size(); ++i)
{
outputs.at(i) =
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins);
}
mm->replace_return(outputs);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
// #define MIGRAPHX_DISABLE_OMP

#include <cmath>
#include <migraphx/config.hpp>
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
...
@@ -82,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        simplify_algebra{},
        simplify_reshapes{},
        dead_code_elimination{},
        simplify_reshapes{},
        simplify_algebra{},
...
@@ -41,7 +41,7 @@ class x_model
    void set_shape(migraphx::shape);
};

x_model create_xmodel(migraphx::const_module_ref mod);

migraphx::argument execute(const x_model& xmodel,
                           const migraphx::shape& output_shape,
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
    // TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
    // assuming all FPGA instructions are in one contiguous range
    pm->insert_instructions(pm->end(), first, std::next(last), {});

    migraphx::instruction_ref placeholder_ins;
    for(auto it : iterator_for(mod))
    {
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void x_model::set_shape(migraphx::shape s) { shape = s; }

x_model create_xmodel(migraphx::const_module_ref mod)
{
    std::cout << "Calling an external function: create_xmodel!\n";
    x_model xmodel;
...
@@ -22,7 +22,7 @@
# THE SOFTWARE.
# ####################################################################################

list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(miopen)

# rocblas
@@ -33,7 +33,13 @@ if(NOT TARGET MIOpen)
    message(SEND_ERROR "Cant find miopen")
endif()

find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)

if(BUILD_DEV)
    set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
else()
    set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif()

include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
@@ -91,6 +97,7 @@ add_library(migraphx_gpu
    compile_miopen.cpp
    compiler.cpp
    device_name.cpp
    fuse_ck.cpp
    fuse_mlir.cpp
    fuse_ops.cpp
    gather.cpp
@@ -119,6 +126,7 @@ add_library(migraphx_gpu
    schedule_model.cpp
    sync_device.cpp
    target.cpp
    time_op.cpp
    topk.cpp
    write_literals.cpp
    ${JIT_GPU_SRCS}
@@ -237,9 +245,10 @@ else()
endif()

target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)

add_subdirectory(driver)
add_subdirectory(hiprtc)

rocm_install_targets(
    TARGETS migraphx_gpu migraphx_device compile_for_gpu
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
@@ -168,10 +169,11 @@ std::string make_transformer_args(std::vector<std::string> transformers)
    return join_strings(std::move(transformers), ", ");
}

void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{
    module m = pm;
    run_passes(m,
               {rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
    cpp_generator g;
    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
@@ -184,8 +186,141 @@ std::string generate_pointwise(const module& pm, const std::string& name)
    // Add explicit conversions
    g.fresult(
        [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
    gg.create_function(g.generate_module(m)
                           .set_attributes({"__device__", "__attribute__((const))"})
                           .set_generic_types(m)
                           .set_name(name));
}

std::string generate_pointwise(const module& pm, const std::string& name)
{
    cpp_generator g;
    generate_pointwise(g, pm, name);
    return g.str();
}
std::string reduce_op::str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
}
void reduce_op::set(instruction_ref ins, const operation& op)
{
if(op.name() == "reduce_sum")
{
reduction = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
reduction = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
read = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type}, reduce_type))
read = mean;
else
write = mean;
}
else if(op.name() == "reduce_max")
{
reduction = "op::max{}";
init = "lowest{}";
}
else if(op.name() == "reduce_min")
{
reduction = "op::min{}";
init = "highest{}";
}
else if(op.name() == "reduce_prod")
{
reduction = "op::product{}";
init = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
}
std::string reduce_op::generate(instruction_ref ins, const std::string& x)
{
reduce_op r{x};
r.set(ins, ins->get_operator());
return r.str();
}
static bool use_lazy_inner(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
auto output = ins->outputs().front();
return contains(output->name(), "reduce") or output->name() == "@return";
}
std::string generate_reduce(const module& m, const std::string& name)
{
cpp_generator g;
auto ilens = m.get_parameter_shapes().begin()->second.lens();
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
}
else if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input : ins->inputs())
{
if(input->name() != "@param")
continue;
if(contains(tensors, input))
continue;
inner_names[input] += "[out_idx]";
}
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if(tensors.empty())
return call_function;
const std::string inner_template =
"r.${inner}([=](${params}) { return ${call}; })(${args})";
std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
f.add_generic_param("out_idx");
f.unused_param("out_idx");
g.create_function(f);
    return g.str();
}
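A toy re-implementation (assumption-level, not MIGraphX's utility) of the ${key} substitution interpolate_string performs above, using the inner-lambda template as input:

#include <cassert>
#include <string>
#include <utility>
#include <vector>

std::string substitute(std::string tmpl,
                       const std::vector<std::pair<std::string, std::string>>& vars)
{
    for(const auto& [key, val] : vars)
    {
        const std::string token = "${" + key + "}";
        // Replace every occurrence of ${key} with its value
        for(auto pos = tmpl.find(token); pos != std::string::npos; pos = tmpl.find(token))
            tmpl.replace(pos, token.size(), val);
    }
    return tmpl;
}

int main()
{
    auto s = substitute("r.${inner}([=](${params}) { return ${call}; })(${args})",
                        {{"inner", "lazy_inner"},
                         {"params", "auto x_lambda_param"},
                         {"args", "x"},
                         {"call", "pointwise0(x_lambda_param)"}});
    assert(s ==
           "r.lazy_inner([=](auto x_lambda_param) { return pointwise0(x_lambda_param); })(x)");
}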
@@ -196,7 +331,17 @@ static std::vector<std::string> get_op_names(const module& m)
    {
        if(starts_with(ins.name(), "@"))
            continue;
        if(ins.name() == "multibroadcast")
            continue;
        if(ins.name() == "pointwise")
        {
            auto names = get_op_names(*ins.module_inputs().front());
            result.insert(result.end(), names.begin(), names.end());
        }
        else
        {
            result.push_back(ins.name());
        }
    }
    return result;
}
...
@@ -32,6 +32,13 @@
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/value.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/file_buffer.hpp>
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
@@ -49,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC

std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
    return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
@@ -63,6 +67,7 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
    throw make_exception(ctx, hiprtc_error(err, msg));
}

// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
    hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX())
@@ -110,21 +115,19 @@ struct hiprtc_program
    std::string cpp_src  = "";
    std::string cpp_name = "";

    hiprtc_program(std::vector<hiprtc_src_file> srcs)
    {
        for(auto&& src : srcs)
        {
            if(ends_with(src.path, ".cpp"))
            {
                cpp_src  = std::move(src.content);
                cpp_name = std::move(src.path);
            }
            else
            {
                headers.push_back(std::move(src.content));
                include_names.push_back(std::move(src.path));
            }
        }
        prog = hiprtc_program_create(cpp_src.c_str(),
@@ -134,7 +137,7 @@ struct hiprtc_program
                                     include_names.data());
    }

    void compile(const std::vector<std::string>& options) const
    {
        if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
            std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
@@ -175,10 +178,11 @@ struct hiprtc_program
    }
};

std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
                                                           std::string params,
                                                           const std::string& arch)
{
    hiprtc_program prog(std::move(srcs));
    auto options = split_string(params, ' ');
    options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
    // remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
@@ -187,6 +191,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
        options.push_back("-DMIGRAPHX_HAS_DPP=0");
        options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
        options.push_back("-Wno-reserved-identifier");
        options.push_back("-Wno-unused-parameter");
        options.push_back("-Wno-gnu-line-marker");
        options.push_back("-Wno-old-style-cast");
    }
@@ -205,8 +210,50 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
    return {prog.get_code_obj()};
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
if(fs::exists(driver))
{
value v;
v["srcs"] = to_value(hsrcs);
v["params"] = to_value(params);
v["arch"] = to_value(arch);
tmp_dir td{};
auto out = td.path / "output";
process(driver.string() + " " + out.string()).write([&](auto writer) {
to_msgpack(v, writer);
});
if(fs::exists(out))
return {read_buffer(out.string())};
}
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch);
}
#else // MIGRAPHX_USE_HIPRTC
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT
const std::string&)
{
MIGRAPHX_THROW("Not using hiprtc");
}
bool is_hip_clang_compiler()
{
    static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
...
@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
    std::size_t max_global = ctx.get_current_device().get_cu_count() *
                             ctx.get_current_device().get_max_workitems_per_cu();
    return [n, over, max_global](std::size_t local) {
        std::size_t num_elements = n;
        std::size_t groups       = (num_elements + local - 1) / local;
        std::size_t max_blocks   = max_global / local;
        std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
#ifdef MIGRAPHX_USE_HIPRTC
        if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
            num_elements = ((num_elements + local - 1) / local) * local;
#endif
        return std::min(nglobal, num_elements);
    };
}
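A worked sketch of this grid sizing with illustrative numbers, including the HIPRTC workaround's rounding up to whole blocks:

#include <algorithm>
#include <cassert>
#include <cstddef>

int main()
{
    std::size_t n = 10000, local = 256, over = 1, max_global = 120 * 2048;
    std::size_t groups     = (n + local - 1) / local;                     // 40 blocks cover n
    std::size_t max_blocks = max_global / local;                          // 960 blocks fit the device
    std::size_t nglobal    = std::min(max_blocks * over, groups) * local; // 40 * 256 = 10240
    assert(std::min(nglobal, n) == 10000);   // without the workaround, clamp to n
    std::size_t rounded = ((n + local - 1) / local) * local;              // 10240, multiple of local
    assert(std::min(nglobal, rounded) == 10240); // with the workaround, keep whole blocks
}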
@@ -156,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
    assert(not options.inputs.empty());
    assert(options.inputs.size() == options.virtual_inputs.size() or
           options.virtual_inputs.empty());

    std::vector<src_file> srcs = options.additional_src_files;
    std::transform(migraphx_kernels().begin(),
                   migraphx_kernels().end(),
                   std::back_inserter(srcs),
...
@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/time_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -76,6 +77,109 @@ struct compiled_result
    instruction_ref ins;
};
struct problem_cache
{
bool has(const std::string& name, const value& problem) const
{
return contains(cache, create_key(name, problem));
}
void insert(const std::string& name, const value& problem, const value& solution)
{
assert(not solution.is_null());
cache[create_key(name, problem)] = solution;
}
void mark(const std::string& name, const value& problem)
{
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
if(it == cache.end())
return nullopt;
return it->second;
}
static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
}
std::unordered_map<value, value> cache;
};
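A hedged sketch of the problem_cache protocol (the op name and values are illustrative): mark() records a seen problem with a null solution so concurrent plans skip re-tuning it, and insert() later stores the benchmark winner.

problem_cache pc;
migraphx::value problem = {{"m", 64}, {"n", 64}, {"k", 64}};
pc.mark("ck_gemm", problem); // seen, but no solution yet
assert(pc.get("ck_gemm", problem).value().is_null());
pc.insert("ck_gemm", problem, migraphx::value{{"tile", 128}}); // benchmark winner
assert(not pc.get("ck_gemm", problem)->is_null());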
struct compile_plan
{
context* ctx;
operation preop;
instruction_ref ins;
optional<tuning_config> config = nullopt;
std::vector<compiled_result> results = {};
void update_config() { config = get_tuning_config(*ctx, ins, preop); }
template <class Vector>
void add_compiles(Vector& compiles, problem_cache& pc)
{
if(config.has_value())
{
const auto& problem = config->problem;
if(auto sol = pc.get(preop.name(), problem))
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.is_null())
return;
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
else
{
pc.mark(preop.name(), problem);
const auto& solutions = config->solutions;
results.resize(solutions.size());
for(auto i : range(solutions.size()))
{
auto solution = solutions[i];
compiles.emplace_back([=] {
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
}
}
else
{
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins};
});
}
}
const compiled_result& benchmark(problem_cache& pc) const
{
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
return results.front();
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
std::vector<double> times;
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
return time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config->problem, config->solutions.at(i));
return results[i];
}
void replace(module& m, problem_cache& pc) const
{
const auto& cr = benchmark(pc);
cr.replace.replace(m, cr.ins);
}
};
template <class F>
void par_compile(std::size_t n, F f)
{
@@ -84,25 +188,67 @@ void par_compile(std::size_t n, F f)
    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
struct compile_manager
{
    problem_cache pc;
    std::vector<compile_plan> cps;
    bool exhaustive = false;

    template <class... Ts>
    void add_plan(Ts&&... xs)
    {
        cps.push_back({std::forward<Ts>(xs)...});
    }

    void update_configs()
    {
        if(not exhaustive)
            return;
        par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
    }

    void compile(module& m)
    {
        std::vector<std::function<void()>> compiles;
        for(auto& cp : cps)
        {
            cp.add_compiles(compiles, pc);
        }
        par_compile(compiles.size(), [&](auto i) { compiles[i](); });

        // Replace and/or benchmark
        for(const auto& cp : cps)
        {
            if(cp.results.empty())
                continue;
            cp.replace(m, pc);
        }

        // Remove compile_plans that have already been executed
        cps.erase(std::remove_if(cps.begin(),
                                 cps.end(),
                                 [](const auto& cp) { return not cp.results.empty(); }),
                  cps.end());
    }
};

void compile_ops::apply(module& m) const
{
    compile_manager cm;
    cm.exhaustive = exhaustive_tune;
    // Find all precompile ops
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "gpu::precompile_op")
            continue;
        operation preop = any_cast<precompile_op>(ins->get_operator()).op;
        cm.add_plan(ctx, preop, ins);
    }
    cm.update_configs();
    cm.compile(m);
    // Compile already-tuned configs
    cm.compile(m);
    assert(cm.cps.empty());
}
} // namespace gpu
...
@@ -28,33 +28,44 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

namespace {
struct compiler_handle
{
    compiler_compile compile;
    compiler_compile_op compile_op;
    compiler_tuning_config get_tuning_config;
};
} // namespace

auto& compiler_map()
{
    static std::unordered_map<std::string, compiler_handle> m; // NOLINT
    return m;
}

void register_compiler(const std::string& name,
                       compiler_compile c,
                       compiler_compile_op cop,
                       compiler_tuning_config ctg)
{
    compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
}

bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }

compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
{
    return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
}

operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
    return compiler_map().at(name).compile_op(ctx, inputs, v);
}

optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
{
    return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
}

} // namespace gpu
...
...@@ -94,6 +94,10 @@ template <> ...@@ -94,6 +94,10 @@ template <>
struct is_hip_type<std::uint8_t> : std::true_type struct is_hip_type<std::uint8_t> : std::true_type
{ {
}; };
template <>
struct is_hip_type<std::int32_t> : std::true_type
{
};
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})> template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v) void hip_visitor_invoke(T as, V&& v)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
     if(not std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) {
+    visit_tensor_size(s.ndim(), [&](auto ndim) {
         s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
     });
 }
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 template <class V, class F, class... Ts>
 void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
+    visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
 }

 template <class F>
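Both hunks above are the same cleanup: `shape::ndim()` replaces the longer `lens().size()` spelling, and the rank check collapses to one line. The check itself is the pack-into-`initializer_list` idiom; a self-contained sketch, modelling a shape as its vector of lengths so that rank equals `size()`:

#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <vector>

template <class... Shapes>
void check_same_rank(std::size_t expected, const Shapes&... shapes)
{
    // Expanding the pack into a braced list evaluates size() for each argument
    std::initializer_list<std::size_t> ranks = {shapes.size()...};
    if(not std::all_of(
           ranks.begin(), ranks.end(), [&](std::size_t r) { return r == expected; }))
        throw std::runtime_error("Ranks must be the same");
}

// check_same_rank(2, std::vector<int>{3, 4}, std::vector<int>{5, 6}); // passes
// check_same_rank(2, std::vector<int>{3});                            // throws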
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
     size_t class_size = arg0.get_shape().lens().back();
     size_t sample_size = result.get_shape().lens().back();
-    hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
-        result.visit([&](auto out) {
-            hip_visit_views(out)([&](auto output) {
-                gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
-                    auto idx = output.get_shape().multi(i);
-                    auto cdf_begin = cdf.begin() + (idx.front() * class_size);
-                    auto cdf_end = cdf_begin + class_size;
-                    auto sample_iter =
-                        upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
-                    output[i] = std::distance(cdf_begin, sample_iter);
-                });
-            });
+    visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
+        result.visit([&](auto output_host) {
+            hip_visit_views(cdf_host, dist_host, output_host)(
+                [&](auto cdf, auto dist, auto output) {
+                    gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
+                        auto idx = output.get_shape().multi(i);
+                        auto cdf_begin = cdf.begin() + (idx.front() * class_size);
+                        auto cdf_end = cdf_begin + class_size;
+                        auto* sample_iter =
+                            upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
+                        output[i] = std::distance(cdf_begin, sample_iter);
+                    });
+                });
         });
     });
 }
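This change only reorganizes the visitors (host-side `visit_all` for the inputs, a single `hip_visit_views` for all three device views); the sampling rule is unchanged: scale a uniform draw by the row's total CDF mass, then binary-search for the first CDF entry above it. A host-side sketch of that rule:

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// cdf is an unnormalized running sum of class weights for one row
std::size_t sample_class(const std::vector<float>& cdf, float u /* uniform in [0, 1) */)
{
    auto it = std::upper_bound(cdf.begin(), cdf.end(), u * cdf.back());
    return std::distance(cdf.begin(), it);
}

// Example: cdf = {1, 3, 6}, u = 0.4 -> threshold 2.4 -> first entry above is 3 -> class 1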
...
@@ -37,22 +37,26 @@ argument scatter(
     hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
 {
     auto ds = arg0.get_shape();
-    auto inds = arg1.get_shape();
+    auto s1 = arg1.get_shape();
     auto axis_dim_size = ds.lens()[axis];
-    hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
+    hip_visit_all(result, arg0, arg2)([&](auto output, auto data, auto update) {
         auto* output_ptr = device_cast(output.data());
         const auto* data_ptr = device_cast(data.data());
         gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });
-        hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
-            const auto* upd_ptr = device_cast(update.data());
-            const auto* indices_ptr = device_cast(indices.data());
-            gs_launch(stream, inds.elements())([=](auto i) __device__ {
-                auto out_idx = s1.multi(i);
-                auto index = indices_ptr[i];
-                index = index < 0 ? index + axis_dim_size : index;
-                out_idx[axis] = index;
-                output[out_idx] = upd_ptr[i];
-            });
+
+        hip_visit_all(arg1)([&](auto indices) {
+            if constexpr(indices.get_shape().lens.size() == output.get_shape().lens.size())
+            {
+                const auto* upd_ptr = device_cast(update.data());
+                const auto* indices_ptr = device_cast(indices.data());
+                gs_launch(stream, s1.elements())([=](auto i) __device__ {
+                    auto out_idx = indices.get_shape().multi(i);
+                    auto index = indices_ptr[i];
+                    index = index < 0 ? index + axis_dim_size : index;
+                    out_idx[axis] = index;
+                    output[out_idx] = upd_ptr[i];
+                });
+            }
         });
     });
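The scatter kernel first copies `data` into `output`, then overwrites positions along `axis` at the coordinates given by `indices`, with negative indices wrapping; the added `if constexpr` simply keeps the update pass from being instantiated when the indices and output views differ in rank. A 1-D host sketch of the copy-then-overwrite rule:

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> scatter_1d(std::vector<float> data,
                              const std::vector<std::int64_t>& indices,
                              const std::vector<float>& updates)
{
    const auto n = static_cast<std::int64_t>(data.size());
    for(std::size_t i = 0; i < indices.size(); ++i)
    {
        auto index = indices[i];
        index = index < 0 ? index + n : index; // same negative-index wrap as the kernel
        data[static_cast<std::size_t>(index)] = updates[i];
    }
    return data;
}

// scatter_1d({0, 0, 0, 0}, {1, -1}, {5, 7}) -> {0, 5, 0, 7} (-1 wraps to index 3)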
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
     return std::string(props.gcnArchName);
 }

+std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
+
 int get_device_id()
 {
     int device;
@@ -58,7 +60,7 @@ std::string get_device_name()
     auto status = hipGetDeviceProperties(&props, get_device_id());
     if(status != hipSuccess)
         MIGRAPHX_THROW("Failed to get device properties");
-    return get_arch_name(rank<1>{}, props);
+    return get_arch_name(props);
 }
 } // namespace gpu
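`get_arch_name(rank<1>{}, props)` is rank-based overload selection: the `rank<1>` overload is preferred and SFINAEs on `props.gcnArchName`, with lower-rank overloads (not shown in this hunk) presumably serving as fallbacks for HIP headers that lack the field; the new one-argument wrapper just hides the dispatch tag from callers. A sketch of the idiom with a hypothetical `props_t`:

#include <string>

template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

struct props_t
{
    const char* gcnArchName = "gfx90a";
};

// Preferred overload; only viable when P has a gcnArchName member
template <class P>
auto arch_name(rank<1>, const P& p) -> decltype(std::string(p.gcnArchName))
{
    return std::string(p.gcnArchName);
}

// Fallback chosen when the member is missing
template <class P>
std::string arch_name(rank<0>, const P&)
{
    return "unknown";
}

std::string arch_name_of(const props_t& p) { return arch_name(rank<1>{}, p); }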
...
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
 add_executable(gpu-driver
     ${GPU_DRIVER_SRCS}
 )
+rocm_clang_tidy_check(gpu-driver)
 target_include_directories(gpu-driver PRIVATE include)
 target_link_libraries(gpu-driver PRIVATE migraphx_gpu)