Commit 69de5fdb authored by Paul's avatar Paul
Browse files

Add common_dims

parent 59386637
......@@ -34,6 +34,7 @@ add_library(migraphx
argument.cpp
auto_contiguous.cpp
common.cpp
common_dims.cpp
compile_src.cpp
convert_to_json.cpp
cpp_generator.cpp
......
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template<class Iterator>
static auto elements(Iterator start, Iterator last)
{
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
}
template<class Range>
static auto elements(const Range& r)
{
return elements(r.begin(), r.end());
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) == elements(dims2));
common_dims cd;
auto it1 = dims1.begin();
auto it2 = dims2.begin();
std::size_t rem1 = 1;
std::size_t rem2 = 1;
while(it1 != dims1.end() and it2 != dims2.end())
{
auto d1 = *it1;
auto d2 = *it2;
if (d1 == d2)
{
cd.axes_map1.push_back({cd.dims.size()});
cd.axes_map2.push_back({cd.dims.size()});
cd.dims.push_back(d1);
it1++;
it2++;
}
else if (d1 < d2)
{
auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end);
auto n = elements(dims);
if (n != d2)
{
// If not divisible then we can't compute a common dims
if ((d2 % n) != 0)
return {};
rem1 = d2 / n;
}
std::vector<std::size_t> axes(distance(dims));
std::iota(axes.begin(), axes.end(), cd.dims.size());
cd.axes_map1.push_back(axes);
cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if (rem1 != 1)
cd.dims.push_back(rem1);
it1 += distance(dims);
it2++;
}
}
return cd;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2);
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
......@@ -474,7 +474,7 @@ struct match_fold_f
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher(
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
......@@ -489,7 +489,7 @@ struct match_fold_f
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher(
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op;
bool matches = Start;
......@@ -835,10 +835,28 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
template<class T>
inline auto has_attribute(const std::string& name, T value)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) {
auto attributes = ins->get_operator().attributes();
if (not attributes.contains(name))
return false;
return attributes[name].to<T>() == value;
});
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(ms...);
return match::has_attribute("pointwise", true)(ms...);
}
template <class... Ms>
auto reduce(Ms... ms)
{
return match::has_attribute("reduce", true)(ms...);
}
} // namespace match
......
......@@ -913,6 +913,102 @@ struct find_broadcast_reshaper
}
};
struct find_poinwise_reduce_reshape
{
auto matcher() const
{
auto reshaper = match::name({"reshape", "squeeze", "unsqueeze"});
auto skip_contiguous = match::skip(match::name("contiguous"));
auto pointwise_or_reduce = match::any_of(match::pointwise(), match::reduce());
auto reshape_pointwise_or_reduce = reshaper(skip_contiguous(pointwise_or_reduce.bind("x"))).bind("reshape");
return pointwise_or_reduce(match::any_of[match::inputs()](reshape_pointwise_or_reduce));
}
static bool is_pointwise(instruction_ref ins)
{
auto a = ins->get_operator().attributes();
return a.get("pointwise", false);
}
static bool is_reduce(instruction_ref ins)
{
auto a = ins->get_operator().attributes();
return a.get("reduce", false);
}
static bool is_pointwise_or_reduce(instruction_ref ins)
{
auto a = ins->get_operator().attributes();
return a.get("pointwise", false) or a.get("reduce", false);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto dims1 = x_ins->get_shape().lens();
auto dims2 = reshape_ins->get_shape().lens();
std::vector<int64_t> axes;
if (x_ins->get_operator().attributes().get("reduce", false))
{
axes = x_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
}
std::unordered_set<instruction_ref> inss;
instruction_ref entry;
// Collect from inputs
fix([&](auto self, instruction_ref i) {
inss.insert(i);
entry = i;
auto pointwise_or_reduce = [&](instruction_ref input) {
if (input->can_eval())
return false;
return is_pointwise(input);
};
auto it = std::find_if(i->inputs().begin(), i->inputs().end(), pointwise_or_reduce);
if (it == i->inputs().end())
return;
auto it2 = std::find_if(it, i->inputs().end(), pointwise_or_reduce);
// If there is more than one pointwise_reduce than stop
if (it2 != i->inputs().end())
return;
self(*it);
})(x_ins);
// Collect from output
fix([&](auto self, instruction_ref out) {
for(auto output:out->outputs())
{
if (not std::all_of(output->inputs().begin(), output->inputs().end(), [&](auto input) {
return input->can_eval() or contains(inss, input);
}))
continue;
if (not is_pointwise_or_reduce(ins))
continue;
inss.insert(output);
self(output);
}
})(x_ins);
std::vector<instruction_ref> instructions;
std::unordered_set<instruction_ref> aux;
// Topological sort
fix([&](auto self, instruction_ref i) {
instructions.push_back(i);
for(auto output:i->outputs())
{
if (not contains(inss, output))
{
aux.insert(output);
continue;
}
self(output);
}
})(entry);
}
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 4; i++)
......
#include <migraphx/common_dims.hpp>
#include <test.hpp>
TEST_CASE(common1)
{
auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment