Commit 0a7ee4de authored by Paul
Browse files

handle multi axis split

parent 3c160a3f
......@@ -743,6 +743,36 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct
}
}
// Finds the single position, within the `starts`/`ends` vectors of a group of
// slice instructions, at which the slices differ from one another.
//
// Every slice after the first is compared element-by-element against the
// first slice. The comparison is done once for `starts` and once for `ends`.
//
// Returns the differing position when exactly one position differs and it is
// the same position for both `starts` and `ends`. Returns `nullopt` when:
//   - `slices` is empty,
//   - no position differs,
//   - `starts` and `ends` differ at different sets of positions, or
//   - more than one position differs (split_analyzer::pre_split reuses the
//     first slice's ranges on the remaining positions for every slice, which
//     is only valid when those ranges are identical across all slices).
//
// NOTE(review): the returned value is an index into `starts`/`ends`, not a
// tensor dimension; callers map it through `axes` themselves.
// NOTE(review): assumes every slice has `starts`/`ends` of the same length;
// only the common prefix is compared if they are not.
optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices)
{
    // front() and begin() + 1 are undefined on an empty vector.
    if(slices.empty())
        return nullopt;
    const auto& first = slices.front();
    auto get_slice = [](auto& i) -> auto& { return any_cast<op::slice>(i->get_operator()); };
    auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; };
    auto get_end   = [&](auto& i) -> auto& { return get_slice(i).ends; };
    // Collects, in ascending order and without duplicates, every position at
    // which some slice's selected vector differs from the first slice's.
    auto find_different_axis = [&](auto select) {
        std::vector<int64_t> different;
        const auto& reference = select(first);
        std::for_each(slices.begin() + 1, slices.end(), [&](const auto& slice) {
            const auto& values = select(slice);
            auto value_it      = values.begin();
            auto reference_it  = reference.begin();
            // Walk both vectors in lock-step so that position k is always
            // compared with position k of the reference.
            for(; value_it != values.end() and reference_it != reference.end();
                ++value_it, ++reference_it)
            {
                if(*value_it == *reference_it)
                    continue;
                auto i = value_it - values.begin();
                if(not contains(different, i))
                    different.push_back(i);
            }
        });
        // Sort so the starts/ends comparison below does not depend on the
        // order in which the differences were discovered.
        std::sort(different.begin(), different.end());
        return different;
    };
    auto different_starts = find_different_axis(get_start);
    auto different_ends   = find_different_axis(get_end);
    if(different_ends != different_starts)
        return nullopt;
    if(different_starts.size() != 1)
        return nullopt;
    return static_cast<std::size_t>(different_starts.front());
}
std::vector<instruction_ref> get_splits(instruction_ref ins)
{
std::vector<instruction_ref> result;
......@@ -777,6 +807,86 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
return result;
}
struct split_analyzer
{
std::vector<instruction_ref> slices = {};
std::size_t axis = 0;
template<class T>
static auto& get_slice(T& i)
{
return any_cast<op::slice>(i->get_operator());
}
split_analyzer() = default;
explicit split_analyzer(instruction_ref ins)
{
std::vector<instruction_ref> result;
std::copy_if(ins->outputs().begin(),
ins->outputs().end(),
std::back_inserter(result),
[&](auto i) { return i->name() == "slice"; });
if(result.size() < 2)
return;
auto&& axes = get_slice(result.front()).axes;
if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
return;
auto split_axis = find_split_axis(result);
if (not split_axis.has_value())
return;
axis = *split_axis;
auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts[axis]; };
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends[axis]; };
std::sort(
result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); });
if (get_start(result.front()) != 0)
return;
auto it = std::adjacent_find(
result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); });
if(it != result.end())
return;
for(std::size_t i = 0; i < axes.size(); i++)
{
if(ins->get_shape().lens()[axes[i]] != get_slice(result.back()).ends[i])
return;
}
slices = result;
}
bool has_multi_axes() const
{
return get_slice(slices.front()).axes.size() > 1;
}
operation pre_split() const
{
auto slice = get_slice(slices.front());
auto remove_axis = [&](auto& v)
{
v.erase(v.begin()+axis);
};
remove_axis(slice.axes);
remove_axis(slice.starts);
remove_axis(slice.ends);
return slice;
}
instruction_ref insert_pre_split(module& m, instruction_ref ins) const
{
if (not has_multi_axes())
return ins;
return m.insert_instruction(std::next(ins), pre_split(), ins);
}
operation post_split(instruction_ref ins) const
{
if (not has_multi_axes())
return ins->get_operator();
auto slice = get_slice(ins);
return make_op("slice", {{"axes", {slice.axes[axis]}}, {"starts", {slice.starts[axis]}}, {"ends", {slice.ends[axis]}}});
}
};
struct find_splits
{
auto matcher() const
......@@ -870,14 +980,16 @@ struct find_splits
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
split_analyzer analyzer{ins};
if (analyzer.slices.empty())
return;
for(const auto& group : get_split_groups(m, splits))
ins = analyzer.insert_pre_split(m, ins);
for(const auto& group : get_split_groups(m, analyzer.slices))
{
auto start = group.front();
auto split_front = splits.front();
auto split_front = analyzer.slices.front();
auto op = start->get_operator();
if(not is_fusable(start, split_front))
{
......@@ -920,7 +1032,7 @@ struct find_splits
move_instructions_back(m, ins, data_args);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
auto slice_op = any_cast<op::slice>(analyzer.slices.front()->get_operator());
assert(not slice_op.axes.empty());
if(slice_op.axes.size() > 1)
return;
......@@ -951,7 +1063,7 @@ struct find_splits
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c);
m.replace_instruction(i, analyzer.post_split(split), c);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment