Unverified Commit 722c2a8c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Consider buffer data type into indices provably disjoint analysis (#664)

parent a16f0cf5
......@@ -273,20 +273,29 @@ private:
}
for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
auto prev_dtype = prev.dtype;
auto curr_dtype = curr.dtype;
const auto &prev_indice = prev.buffer_indices[i];
const auto &curr_indice = curr.buffer_indices[i];
if (!ExprDeepEqual()(prev_indice, curr_indice)) {
auto prev_indice_bytes =
analyzer_.Simplify(prev_indice * prev_dtype.bytes());
auto curr_indice_bytes =
analyzer_.Simplify(curr_indice * curr_dtype.bytes());
has_same_index = false;
// If both are const, we can check if they are disjoint
// by checking if the bounds are disjoint
// [1024, 2048], [2048, 3072] are disjoint
// [1024, 2048], [1024, 1024] are not disjoint
auto prev_bound = analyzer_.const_int_bound(prev_indice);
auto curr_bound = analyzer_.const_int_bound(curr_indice);
auto prev_bound = analyzer_.const_int_bound(prev_indice_bytes);
auto curr_bound = analyzer_.const_int_bound(curr_indice_bytes);
if (prev_bound.defined() && curr_bound.defined()) {
if (prev_bound->min_value > curr_bound->max_value ||
curr_bound->min_value > prev_bound->max_value) {
if ((prev_bound->min_value) > (curr_bound->max_value) ||
(curr_bound->min_value) > (prev_bound->max_value)) {
range_is_overlap = false;
break;
}
......@@ -294,17 +303,18 @@ private:
// if we can prove prev_indice < curr_indice or prev_indice >
// curr_indice, then they are not overlap
auto prev_dtype = prev_indice.dtype();
auto curr_dtype = curr_indice.dtype();
if (prev_dtype.lanes() != curr_dtype.lanes()) {
auto prev_indices_dtype = prev_indice.dtype();
auto curr_indices_dtype = curr_indice.dtype();
if (prev_indices_dtype.lanes() != curr_indices_dtype.lanes()) {
// can not support different lanes binary op like <, >, <=, >=
// skip otherwise it will lead to error
continue;
}
bool provably_disjoint =
analyzer_.CanProve(prev_indice < curr_indice,
analyzer_.CanProve(prev_indice_bytes < curr_indice_bytes,
arith::ProofStrength::kSymbolicBound) ||
analyzer_.CanProve(prev_indice > curr_indice,
analyzer_.CanProve(prev_indice_bytes > curr_indice_bytes,
arith::ProofStrength::kSymbolicBound);
if (provably_disjoint) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment