Commit 665455e1 authored by Paul's avatar Paul
Browse files

Format

parent 69de5fdb
......@@ -20,18 +20,19 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return it;
}
template<class Iterator>
template <class Iterator>
static auto elements(Iterator start, Iterator last)
{
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
}
template<class Range>
template <class Range>
static auto elements(const Range& r)
{
return elements(r.begin(), r.end());
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2)
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) == elements(dims2));
common_dims cd;
......@@ -43,7 +44,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
{
auto d1 = *it1;
auto d2 = *it2;
if (d1 == d2)
if(d1 == d2)
{
cd.axes_map1.push_back({cd.dims.size()});
cd.axes_map2.push_back({cd.dims.size()});
......@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
it1++;
it2++;
}
else if (d1 < d2)
else if(d1 < d2)
{
auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end);
auto n = elements(dims);
if (n != d2)
if(n != d2)
{
// If not divisible then we can't compute a common dims
if ((d2 % n) != 0)
if((d2 % n) != 0)
return {};
rem1 = d2 / n;
}
......@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if (rem1 != 1)
if(rem1 != 1)
cd.dims.push_back(rem1);
it1 += distance(dims);
it2++;
......
......@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2);
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
......@@ -835,13 +835,12 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
template<class T>
template <class T>
inline auto has_attribute(const std::string& name, T value)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) {
return make_basic_pred_matcher([=](instruction_ref ins) {
auto attributes = ins->get_operator().attributes();
if (not attributes.contains(name))
if(not attributes.contains(name))
return false;
return attributes[name].to<T>() == value;
});
......
......@@ -920,7 +920,8 @@ struct find_poinwise_reduce_reshape
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous"));
auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce());
auto reshape_pointwise_or_reduce = reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
auto reshape_pointwise_or_reduce =
reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce));
}
......@@ -952,7 +953,7 @@ struct find_poinwise_reduce_reshape
auto dims2 = reshape_ins->get_shape().lens();
std::vector<int64_t> axes;
if (x_ins->get_operator().attributes().get("reduce", false))
if(x_ins->get_operator().attributes().get("reduce", false))
{
axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
}
......@@ -963,28 +964,29 @@ struct find_poinwise_reduce_reshape
inss.insert(i);
entry = i;
auto pointwise_or_reduce = [&](instruction_ref input) {
if (input->can_eval())
if(input->can_eval())
return false;
return is_pointwise(input);
};
auto it = std::find_if(i->inputs().begin(), i->inputs().end(), pointwise_or_reduce);
if (it == i->inputs().end())
if(it == i->inputs().end())
return;
auto it2 = std::find_if(it, i->inputs().end(), pointwise_or_reduce);
// If there is more than one pointwise_reduce than stop
if (it2 != i->inputs().end())
if(it2 != i->inputs().end())
return;
self(*it);
})(x_ins);
// Collect from output
fix([&](auto self, instruction_ref out) {
for(auto output:out->outputs())
for(auto output : out->outputs())
{
if (not std::all_of(output->inputs().begin(), output->inputs().end(), [&](auto input) {
if(not std::all_of(
output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or contains(inss, input);
}))
continue;
if (not is_pointwise_or_reduce(ins))
if(not is_pointwise_or_reduce(ins))
continue;
inss.insert(output);
self(output);
......@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape
// Topological sort
fix([&](auto self, instruction_ref i) {
instructions.push_back(i);
for(auto output:i->outputs())
for(auto output : i->outputs())
{
if (not contains(inss, output))
if(not contains(inss, output))
{
aux.insert(output);
continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment