Commit 0b840acc authored by Paul's avatar Paul
Browse files

Format

parent 0a7ee4de
......@@ -751,13 +751,14 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; };
auto find_different_axis = [&](auto select) {
std::vector<int64_t> different;
std::for_each(slices.begin()+1, slices.end(), [&](const auto& slice) {
std::for_each(slices.begin() + 1, slices.end(), [&](const auto& slice) {
auto it = select(slice).begin();
while(it != select(slice).end())
{
auto p = std::mismatch(it, select(slice).end(), select(first).begin(), select(first).end());
auto p = std::mismatch(
it, select(slice).end(), select(first).begin(), select(first).end());
auto i = p.first - select(slice).begin();
if (not contains(different, i))
if(not contains(different, i))
different.push_back(i);
it = p.first;
}
......@@ -766,9 +767,9 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices
};
auto different_starts = find_different_axis(get_start);
auto different_ends = find_different_axis(get_end);
if (different_ends != different_starts)
if(different_ends != different_starts)
return nullopt;
if (different_starts.empty())
if(different_starts.empty())
return nullopt;
return different_starts.front();
}
......@@ -812,7 +813,7 @@ struct split_analyzer
std::vector<instruction_ref> slices = {};
std::size_t axis = 0;
template<class T>
template <class T>
static auto& get_slice(T& i)
{
return any_cast<op::slice>(i->get_operator());
......@@ -829,20 +830,23 @@ struct split_analyzer
if(result.size() < 2)
return;
auto&& axes = get_slice(result.front()).axes;
if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
if(std::any_of(
result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
return;
auto split_axis = find_split_axis(result);
if (not split_axis.has_value())
if(not split_axis.has_value())
return;
axis = *split_axis;
auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts[axis]; };
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends[axis]; };
std::sort(
result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); });
if (get_start(result.front()) != 0)
std::sort(result.begin(), result.end(), [&](auto x, auto y) {
return get_start(x) < get_start(y);
});
if(get_start(result.front()) != 0)
return;
auto it = std::adjacent_find(
result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); });
auto it = std::adjacent_find(result.begin(), result.end(), [&](auto x, auto y) {
return get_end(x) != get_start(y);
});
if(it != result.end())
return;
for(std::size_t i = 0; i < axes.size(); i++)
......@@ -853,18 +857,12 @@ struct split_analyzer
slices = result;
}
bool has_multi_axes() const
{
return get_slice(slices.front()).axes.size() > 1;
}
bool has_multi_axes() const { return get_slice(slices.front()).axes.size() > 1; }
operation pre_split() const
{
auto slice = get_slice(slices.front());
auto remove_axis = [&](auto& v)
{
v.erase(v.begin()+axis);
};
auto remove_axis = [&](auto& v) { v.erase(v.begin() + axis); };
remove_axis(slice.axes);
remove_axis(slice.starts);
remove_axis(slice.ends);
......@@ -873,17 +871,20 @@ struct split_analyzer
instruction_ref insert_pre_split(module& m, instruction_ref ins) const
{
if (not has_multi_axes())
if(not has_multi_axes())
return ins;
return m.insert_instruction(std::next(ins), pre_split(), ins);
}
operation post_split(instruction_ref ins) const
{
if (not has_multi_axes())
if(not has_multi_axes())
return ins->get_operator();
auto slice = get_slice(ins);
return make_op("slice", {{"axes", {slice.axes[axis]}}, {"starts", {slice.starts[axis]}}, {"ends", {slice.ends[axis]}}});
return make_op("slice",
{{"axes", {slice.axes[axis]}},
{"starts", {slice.starts[axis]}},
{"ends", {slice.ends[axis]}}});
}
};
......@@ -981,7 +982,7 @@ struct find_splits
{
auto ins = r.result;
split_analyzer analyzer{ins};
if (analyzer.slices.empty())
if(analyzer.slices.empty())
return;
ins = analyzer.insert_pre_split(m, ins);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment