Commit 665455e1 authored by Paul's avatar Paul
Browse files

Format

parent 69de5fdb
......@@ -20,30 +20,31 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return it;
}
template<class Iterator>
template <class Iterator>
static auto elements(Iterator start, Iterator last)
{
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
}
template<class Range>
template <class Range>
static auto elements(const Range& r)
{
return elements(r.begin(), r.end());
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2)
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) == elements(dims2));
common_dims cd;
auto it1 = dims1.begin();
auto it2 = dims2.begin();
auto it1 = dims1.begin();
auto it2 = dims2.begin();
std::size_t rem1 = 1;
std::size_t rem2 = 1;
while(it1 != dims1.end() and it2 != dims2.end())
{
auto d1 = *it1;
auto d2 = *it2;
if (d1 == d2)
if(d1 == d2)
{
cd.axes_map1.push_back({cd.dims.size()});
cd.axes_map2.push_back({cd.dims.size()});
......@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
it1++;
it2++;
}
else if (d1 < d2)
else if(d1 < d2)
{
auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end);
auto n = elements(dims);
if (n != d2)
auto dims = range(it1, dim_end);
auto n = elements(dims);
if(n != d2)
{
// If not divisible then we can't compute a common dims
if ((d2 % n) != 0)
if((d2 % n) != 0)
return {};
rem1 = d2 / n;
}
......@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if (rem1 != 1)
if(rem1 != 1)
cd.dims.push_back(rem1);
it1 += distance(dims);
it2++;
......
......@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2);
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
......@@ -835,16 +835,15 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
template<class T>
template <class T>
inline auto has_attribute(const std::string& name, T value)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) {
auto attributes = ins->get_operator().attributes();
if (not attributes.contains(name))
return false;
return attributes[name].to<T>() == value;
});
return make_basic_pred_matcher([=](instruction_ref ins) {
auto attributes = ins->get_operator().attributes();
if(not attributes.contains(name))
return false;
return attributes[name].to<T>() == value;
});
}
template <class... Ms>
......
......@@ -917,10 +917,11 @@ struct find_poinwise_reduce_reshape
{
auto matcher() const
{
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous"));
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous"));
auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce());
auto reshape_pointwise_or_reduce = reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
auto reshape_pointwise_or_reduce =
reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce));
}
......@@ -944,15 +945,15 @@ struct find_poinwise_reduce_reshape
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens();
std::vector<int64_t> axes;
if (x_ins->get_operator().attributes().get("reduce", false))
if(x_ins->get_operator().attributes().get("reduce", false))
{
axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
}
......@@ -961,30 +962,31 @@ struct find_poinwise_reduce_reshape
// Collect from inputs
fix([&](auto self, instruction_ref i) {
inss.insert(i);
entry = i;
entry = i;
auto pointwise_or_reduce = [&](instruction_ref input) {
if (input->can_eval())
if(input->can_eval())
return false;
return is_pointwise(input);
};
auto it = std::find_if(i->inputs().begin(), i->inputs().end(), pointwise_or_reduce);
if (it == i->inputs().end())
if(it == i->inputs().end())
return;
auto it2 = std::find_if(it, i->inputs().end(), pointwise_or_reduce);
// If there is more than one pointwise_reduce than stop
if (it2 != i->inputs().end())
if(it2 != i->inputs().end())
return;
self(*it);
})(x_ins);
// Collect from output
fix([&](auto self, instruction_ref out) {
for(auto output:out->outputs())
for(auto output : out->outputs())
{
if (not std::all_of(output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or contains(inss, input);
}))
if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or contains(inss, input);
}))
continue;
if (not is_pointwise_or_reduce(ins))
if(not is_pointwise_or_reduce(ins))
continue;
inss.insert(output);
self(output);
......@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape
// Topological sort
fix([&](auto self, instruction_ref i) {
instructions.push_back(i);
for(auto output:i->outputs())
for(auto output : i->outputs())
{
if (not contains(inss, output))
if(not contains(inss, output))
{
aux.insert(output);
continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment