Commit 665455e1 authored by Paul's avatar Paul
Browse files

Format

parent 69de5fdb
...@@ -20,30 +20,31 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim) ...@@ -20,30 +20,31 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return it; return it;
} }
template<class Iterator> template <class Iterator>
static auto elements(Iterator start, Iterator last) static auto elements(Iterator start, Iterator last)
{ {
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{}); return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
} }
template<class Range> template <class Range>
static auto elements(const Range& r) static auto elements(const Range& r)
{ {
return elements(r.begin(), r.end()); return elements(r.begin(), r.end());
} }
common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2) common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{ {
assert(elements(dims1) == elements(dims2)); assert(elements(dims1) == elements(dims2));
common_dims cd; common_dims cd;
auto it1 = dims1.begin(); auto it1 = dims1.begin();
auto it2 = dims2.begin(); auto it2 = dims2.begin();
std::size_t rem1 = 1; std::size_t rem1 = 1;
std::size_t rem2 = 1; std::size_t rem2 = 1;
while(it1 != dims1.end() and it2 != dims2.end()) while(it1 != dims1.end() and it2 != dims2.end())
{ {
auto d1 = *it1; auto d1 = *it1;
auto d2 = *it2; auto d2 = *it2;
if (d1 == d2) if(d1 == d2)
{ {
cd.axes_map1.push_back({cd.dims.size()}); cd.axes_map1.push_back({cd.dims.size()});
cd.axes_map2.push_back({cd.dims.size()}); cd.axes_map2.push_back({cd.dims.size()});
...@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st ...@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
it1++; it1++;
it2++; it2++;
} }
else if (d1 < d2) else if(d1 < d2)
{ {
auto dim_end = compute_end_dim(it1, dims1.begin(), d2); auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end); auto dims = range(it1, dim_end);
auto n = elements(dims); auto n = elements(dims);
if (n != d2) if(n != d2)
{ {
// If not divisible then we can't compute a common dims // If not divisible then we can't compute a common dims
if ((d2 % n) != 0) if((d2 % n) != 0)
return {}; return {};
rem1 = d2 / n; rem1 = d2 / n;
} }
...@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st ...@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
cd.axes_map2.push_back(axes); cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end()); cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if (rem1 != 1) if(rem1 != 1)
cd.dims.push_back(rem1); cd.dims.push_back(rem1);
it1 += distance(dims); it1 += distance(dims);
it2++; it2++;
......
...@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct common_dims struct common_dims
{ {
static common_dims compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2); static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1; std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2; std::vector<std::vector<std::size_t>> axes_map2;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP #endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
...@@ -835,16 +835,15 @@ inline auto has_attribute(const std::string& name) ...@@ -835,16 +835,15 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); }); [=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
} }
template<class T> template <class T>
inline auto has_attribute(const std::string& name, T value) inline auto has_attribute(const std::string& name, T value)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher([=](instruction_ref ins) {
[=](instruction_ref ins) { auto attributes = ins->get_operator().attributes();
auto attributes = ins->get_operator().attributes(); if(not attributes.contains(name))
if (not attributes.contains(name)) return false;
return false; return attributes[name].to<T>() == value;
return attributes[name].to<T>() == value; });
});
} }
template <class... Ms> template <class... Ms>
......
...@@ -917,10 +917,11 @@ struct find_poinwise_reduce_reshape ...@@ -917,10 +917,11 @@ struct find_poinwise_reduce_reshape
{ {
auto matcher() const auto matcher() const
{ {
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"}); auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous")); auto skip_contiguous = match::skip(match::name("contiguous"));
auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce()); auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce());
auto reshape_pointwise_or_reduce = reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape"); auto reshape_pointwise_or_reduce =
reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce)); return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce));
} }
...@@ -944,15 +945,15 @@ struct find_poinwise_reduce_reshape ...@@ -944,15 +945,15 @@ struct find_poinwise_reduce_reshape
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"]; auto reshape_ins = r.instructions["reshape"];
auto dims1 = x_ins->get_shape().lens(); auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens(); auto dims2 = reshape_ins->get_shape().lens();
std::vector<int64_t> axes; std::vector<int64_t> axes;
if (x_ins->get_operator().attributes().get("reduce", false)) if(x_ins->get_operator().attributes().get("reduce", false))
{ {
axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>(); axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
} }
...@@ -961,30 +962,31 @@ struct find_poinwise_reduce_reshape ...@@ -961,30 +962,31 @@ struct find_poinwise_reduce_reshape
// Collect from inputs // Collect from inputs
fix([&](auto self, instruction_ref i) { fix([&](auto self, instruction_ref i) {
inss.insert(i); inss.insert(i);
entry = i; entry = i;
auto pointwise_or_reduce = [&](instruction_ref input) { auto pointwise_or_reduce = [&](instruction_ref input) {
if (input->can_eval()) if(input->can_eval())
return false; return false;
return is_pointwise(input); return is_pointwise(input);
}; };
auto it = std::find_if(i->inputs().begin(), i->inputs().end(), pointwise_or_reduce); auto it = std::find_if(i->inputs().begin(), i->inputs().end(), pointwise_or_reduce);
if (it == i->inputs().end()) if(it == i->inputs().end())
return; return;
auto it2 = std::find_if(it, i->inputs().end(), pointwise_or_reduce); auto it2 = std::find_if(it, i->inputs().end(), pointwise_or_reduce);
// If there is more than one pointwise_reduce than stop // If there is more than one pointwise_reduce than stop
if (it2 != i->inputs().end()) if(it2 != i->inputs().end())
return; return;
self(*it); self(*it);
})(x_ins); })(x_ins);
// Collect from output // Collect from output
fix([&](auto self, instruction_ref out) { fix([&](auto self, instruction_ref out) {
for(auto output:out->outputs()) for(auto output : out->outputs())
{ {
if (not std::all_of(output->inputs().begin(), output->inputs().end(), [&](auto input) { if(not std::all_of(
return input->can_eval() or contains(inss, input); output->inputs().begin(), output->inputs().end(), [&](auto input) {
})) return input->can_eval() or contains(inss, input);
}))
continue; continue;
if (not is_pointwise_or_reduce(ins)) if(not is_pointwise_or_reduce(ins))
continue; continue;
inss.insert(output); inss.insert(output);
self(output); self(output);
...@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape ...@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape
// Topological sort // Topological sort
fix([&](auto self, instruction_ref i) { fix([&](auto self, instruction_ref i) {
instructions.push_back(i); instructions.push_back(i);
for(auto output:i->outputs()) for(auto output : i->outputs())
{ {
if (not contains(inss, output)) if(not contains(inss, output))
{ {
aux.insert(output); aux.insert(output);
continue; continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment