Commit 0b840acc authored by Paul's avatar Paul
Browse files

Format

parent 0a7ee4de
...@@ -745,19 +745,20 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct ...@@ -745,19 +745,20 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct
optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices) optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices)
{ {
auto first = slices.front(); auto first = slices.front();
auto get_slice = [](auto& i) -> auto& { return any_cast<op::slice>(i->get_operator()); }; auto get_slice = [](auto& i) -> auto& { return any_cast<op::slice>(i->get_operator()); };
auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; }; auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; };
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; }; auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; };
auto find_different_axis = [&](auto select) { auto find_different_axis = [&](auto select) {
std::vector<int64_t> different; std::vector<int64_t> different;
std::for_each(slices.begin()+1, slices.end(), [&](const auto& slice) { std::for_each(slices.begin() + 1, slices.end(), [&](const auto& slice) {
auto it = select(slice).begin(); auto it = select(slice).begin();
while(it != select(slice).end()) while(it != select(slice).end())
{ {
auto p = std::mismatch(it, select(slice).end(), select(first).begin(), select(first).end()); auto p = std::mismatch(
it, select(slice).end(), select(first).begin(), select(first).end());
auto i = p.first - select(slice).begin(); auto i = p.first - select(slice).begin();
if (not contains(different, i)) if(not contains(different, i))
different.push_back(i); different.push_back(i);
it = p.first; it = p.first;
} }
...@@ -765,10 +766,10 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices ...@@ -765,10 +766,10 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices
return different; return different;
}; };
auto different_starts = find_different_axis(get_start); auto different_starts = find_different_axis(get_start);
auto different_ends = find_different_axis(get_end); auto different_ends = find_different_axis(get_end);
if (different_ends != different_starts) if(different_ends != different_starts)
return nullopt; return nullopt;
if (different_starts.empty()) if(different_starts.empty())
return nullopt; return nullopt;
return different_starts.front(); return different_starts.front();
} }
...@@ -810,9 +811,9 @@ std::vector<instruction_ref> get_splits(instruction_ref ins) ...@@ -810,9 +811,9 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
struct split_analyzer struct split_analyzer
{ {
std::vector<instruction_ref> slices = {}; std::vector<instruction_ref> slices = {};
std::size_t axis = 0; std::size_t axis = 0;
template<class T> template <class T>
static auto& get_slice(T& i) static auto& get_slice(T& i)
{ {
return any_cast<op::slice>(i->get_operator()); return any_cast<op::slice>(i->get_operator());
...@@ -823,26 +824,29 @@ struct split_analyzer ...@@ -823,26 +824,29 @@ struct split_analyzer
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
std::copy_if(ins->outputs().begin(), std::copy_if(ins->outputs().begin(),
ins->outputs().end(), ins->outputs().end(),
std::back_inserter(result), std::back_inserter(result),
[&](auto i) { return i->name() == "slice"; }); [&](auto i) { return i->name() == "slice"; });
if(result.size() < 2) if(result.size() < 2)
return; return;
auto&& axes = get_slice(result.front()).axes; auto&& axes = get_slice(result.front()).axes;
if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; })) if(std::any_of(
result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
return; return;
auto split_axis = find_split_axis(result); auto split_axis = find_split_axis(result);
if (not split_axis.has_value()) if(not split_axis.has_value())
return; return;
axis = *split_axis; axis = *split_axis;
auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts[axis]; }; auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts[axis]; };
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends[axis]; }; auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends[axis]; };
std::sort( std::sort(result.begin(), result.end(), [&](auto x, auto y) {
result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); }); return get_start(x) < get_start(y);
if (get_start(result.front()) != 0) });
if(get_start(result.front()) != 0)
return; return;
auto it = std::adjacent_find( auto it = std::adjacent_find(result.begin(), result.end(), [&](auto x, auto y) {
result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); return get_end(x) != get_start(y);
});
if(it != result.end()) if(it != result.end())
return; return;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
...@@ -853,18 +857,12 @@ struct split_analyzer ...@@ -853,18 +857,12 @@ struct split_analyzer
slices = result; slices = result;
} }
bool has_multi_axes() const bool has_multi_axes() const { return get_slice(slices.front()).axes.size() > 1; }
{
return get_slice(slices.front()).axes.size() > 1;
}
operation pre_split() const operation pre_split() const
{ {
auto slice = get_slice(slices.front()); auto slice = get_slice(slices.front());
auto remove_axis = [&](auto& v) auto remove_axis = [&](auto& v) { v.erase(v.begin() + axis); };
{
v.erase(v.begin()+axis);
};
remove_axis(slice.axes); remove_axis(slice.axes);
remove_axis(slice.starts); remove_axis(slice.starts);
remove_axis(slice.ends); remove_axis(slice.ends);
...@@ -873,17 +871,20 @@ struct split_analyzer ...@@ -873,17 +871,20 @@ struct split_analyzer
instruction_ref insert_pre_split(module& m, instruction_ref ins) const instruction_ref insert_pre_split(module& m, instruction_ref ins) const
{ {
if (not has_multi_axes()) if(not has_multi_axes())
return ins; return ins;
return m.insert_instruction(std::next(ins), pre_split(), ins); return m.insert_instruction(std::next(ins), pre_split(), ins);
} }
operation post_split(instruction_ref ins) const operation post_split(instruction_ref ins) const
{ {
if (not has_multi_axes()) if(not has_multi_axes())
return ins->get_operator(); return ins->get_operator();
auto slice = get_slice(ins); auto slice = get_slice(ins);
return make_op("slice", {{"axes", {slice.axes[axis]}}, {"starts", {slice.starts[axis]}}, {"ends", {slice.ends[axis]}}}); return make_op("slice",
{{"axes", {slice.axes[axis]}},
{"starts", {slice.starts[axis]}},
{"ends", {slice.ends[axis]}}});
} }
}; };
...@@ -981,7 +982,7 @@ struct find_splits ...@@ -981,7 +982,7 @@ struct find_splits
{ {
auto ins = r.result; auto ins = r.result;
split_analyzer analyzer{ins}; split_analyzer analyzer{ins};
if (analyzer.slices.empty()) if(analyzer.slices.empty())
return; return;
ins = analyzer.insert_pre_split(m, ins); ins = analyzer.insert_pre_split(m, ins);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment