Unverified Commit 54aaec98 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

Refactor index handling in BufferStore and BufferLoad to promote 64-bit integers (#796)

- Updated index processing in `BufferStore` and `BufferLoad` to ensure that integer indices with less than 64 bits are promoted to 64-bit integers.
- Introduced a new array to store the modified indices before updating the original indices, enhancing clarity and maintainability of the code.
parent 7467f2b3
...@@ -123,6 +123,7 @@ private: ...@@ -123,6 +123,7 @@ private:
auto buffer_store = auto buffer_store =
Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op)); Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto indices = buffer_store->indices; auto indices = buffer_store->indices;
Array<PrimExpr> new_indices;
for (auto index : indices) { for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) { if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index); auto int_bound = analyzer_->const_int_bound(index);
...@@ -130,10 +131,13 @@ private: ...@@ -130,10 +131,13 @@ private:
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter; Int64Promoter promoter;
index = promoter(index); index = promoter(index);
new_indices.push_back(index);
continue;
} }
} }
new_indices.push_back(index);
} }
buffer_store.CopyOnWrite()->indices = indices; buffer_store.CopyOnWrite()->indices = new_indices;
return std::move(buffer_store); return std::move(buffer_store);
} }
...@@ -141,6 +145,7 @@ private: ...@@ -141,6 +145,7 @@ private:
auto buffer_load = auto buffer_load =
Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op)); Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
auto indices = buffer_load->indices; auto indices = buffer_load->indices;
Array<PrimExpr> new_indices;
for (auto index : indices) { for (auto index : indices) {
if (index->dtype.is_int() && index->dtype.bits() < 64) { if (index->dtype.is_int() && index->dtype.bits() < 64) {
auto int_bound = analyzer_->const_int_bound(index); auto int_bound = analyzer_->const_int_bound(index);
...@@ -148,10 +153,13 @@ private: ...@@ -148,10 +153,13 @@ private:
int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) { int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
Int64Promoter promoter; Int64Promoter promoter;
index = promoter(index); index = promoter(index);
new_indices.push_back(index);
continue;
} }
} }
new_indices.push_back(index);
} }
buffer_load.CopyOnWrite()->indices = indices; buffer_load.CopyOnWrite()->indices = new_indices;
return std::move(buffer_load); return std::move(buffer_load);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment