"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "006e1a0190057e94905987712ca245bf80ec09d0"
Unverified Commit a9611738 authored by Lei Wang, committed by GitHub

[Index] Relocate Int64 Auto Promoter to ConfigBitWidth Pass, removing it from FlattenBuffer (#714)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling

- Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes; a sketch of this pattern follows this list.
- Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management.
- Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body.
- Removed obsolete code and improved overall code clarity and maintainability.
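A minimal sketch of the then_case replacement described above, assuming TVM's `tir::StmtExprMutator` API; the class name is illustrative and attribute preservation is omitted for brevity, so this is not the exact helper added by the commit:

```cpp
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

using namespace tir;

// Illustrative mutator: collapse IfThenElse nodes to their then-branch.
class ThenCaseReplacer : public StmtExprMutator {
public:
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    // Recurse into the then-branch first so nested IfThenElse nodes are
    // also collapsed, then drop the condition and else-branch entirely.
    return this->VisitStmt(op->then_case);
  }
};

} // namespace tl
} // namespace tvm
```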

* lint fix

* Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls

- Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves (illustrated after this list).
- Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations.
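A generic C++ illustration (not code from this patch) of why the returns were changed: returning a local by value already permits copy elision, whereas wrapping it in `std::move` can suppress that and typically triggers `-Wpessimizing-move` or `-Wredundant-move` warnings.

```cpp
#include <utility>
#include <vector>

std::vector<int> MakeIndices() {
  std::vector<int> indices = {0, 1, 2};
  // return std::move(indices);  // before: redundant move, can block NRVO
  return indices;                // after: compiler elides or moves automatically
}
```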

* test fix

* Enhance global read detection in pipeline planning

- Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses.
- Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis; see the sketch following this list.
- Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code.
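A rough sketch of the condition-tracking idea described above, assuming TVM's `tir::StmtExprVisitor`; the class and member names (`GlobalReadDetector`, `in_condition_`) are illustrative, not the exact code in the pipeline planner:

```cpp
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

using namespace tir;

// Illustrative visitor: buffer loads inside an IfThenElse condition are also
// counted as global reads, instead of only visiting the branch bodies.
class GlobalReadDetector : public StmtExprVisitor {
public:
  bool has_global_read{false};
  bool has_global_read_in_condition{false};

private:
  void VisitStmt_(const IfThenElseNode *op) final {
    // Walk the condition expression with the flag set so reads found there
    // can be attributed to the condition.
    in_condition_ = true;
    this->VisitExpr(op->condition);
    in_condition_ = false;
    this->VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      this->VisitStmt(op->else_case.value());
    }
  }

  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "global") {
      has_global_read = true;
      if (in_condition_) {
        has_global_read_in_condition = true;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  bool in_condition_{false};
};

} // namespace tl
} // namespace tvm
```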

* Add IndexLegalizer to enforce int64 for out-of-bound indices

- Introduced the IndexLegalizer class to ensure that indices in BufferStore and BufferLoad nodes are promoted to int64 when they exceed their type bounds (the promotion criterion is sketched after this list).
- Refactored the Int64Promoter logic from flatten_buffer.cc into IndexLegalizer, improving code organization and reusability.
- Updated the ConfigIndexBitwidth pass to apply IndexLegalizer after rewriting the body, enhancing the handling of index bitwidths in transformations.
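As a worked illustration of the promotion criterion: an int32 index can represent at most 2^31 - 1 = 2147483647, so indices whose proven bounds reach the type limits are rewritten to int64. The helper below is a hypothetical standalone version of that check (`NeedsInt64` is not a function in the patch):

```cpp
#include <cstdint>

// Hypothetical helper mirroring the IndexLegalizer check: promote when the
// proven constant bounds of an index do not fit its current signed type.
bool NeedsInt64(int bits, int64_t proven_min, int64_t proven_max) {
  const int64_t type_max = (int64_t{1} << (bits - 1)) - 1; // 2147483647 for int32
  const int64_t type_min = -(int64_t{1} << (bits - 1));    // -2147483648 for int32
  return proven_max >= type_max || proven_min < type_min;
}
```
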
parent c1eef511
#include "../op/builtin.h" #include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include <tvm/ffi/function.h> #include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
...@@ -10,6 +11,7 @@ namespace tvm { ...@@ -10,6 +11,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace arith;
class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter { class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
public: public:
using Parent = IndexDataTypeRewriter; using Parent = IndexDataTypeRewriter;
...@@ -68,6 +70,92 @@ protected: ...@@ -68,6 +70,92 @@ protected:
int _index_bitwidth_; int _index_bitwidth_;
}; };
class IndexLegalizer : public IRMutatorWithAnalyzer {
public:
  static Stmt Rewrite(Stmt stmt) {
    Analyzer ana;
    auto pass = IndexLegalizer(&ana);
    return pass.VisitStmt(stmt);
  }

private:
  explicit IndexLegalizer(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}

  class Int64Promoter : public IndexDataTypeRewriter {
  public:
    using Parent = IndexDataTypeRewriter;
    PrimExpr VisitExpr_(const VarNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), GetRef<Var>(op));
      }
      return GetRef<PrimExpr>(op);
    }
    PrimExpr VisitExpr_(const IntImmNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return IntImm(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }
    PrimExpr VisitExpr_(const CastNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }
    Stmt VisitStmt_(const BufferStoreNode *op) final {
      // Force indices to be int64
      auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
      return std::move(node);
    }
    PrimExpr VisitExpr_(const BufferLoadNode *op) final {
      auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
      return std::move(node);
    }
  };

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto buffer_store =
        Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto indices = buffer_store->indices;
    for (auto index : indices) {
      if (index->dtype.is_int() && index->dtype.bits() < 64) {
        auto int_bound = analyzer_->const_int_bound(index);
        if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
            int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
          Int64Promoter promoter;
          index = promoter(index);
        }
      }
    }
    buffer_store.CopyOnWrite()->indices = indices;
    return std::move(buffer_store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto buffer_load =
        Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    auto indices = buffer_load->indices;
    for (auto index : indices) {
      if (index->dtype.is_int() && index->dtype.bits() < 64) {
        auto int_bound = analyzer_->const_int_bound(index);
        if (int_bound->max_value >= (1LL << (index->dtype.bits() - 1)) - 1 ||
            int_bound->min_value < -(1LL << (index->dtype.bits() - 1))) {
          Int64Promoter promoter;
          index = promoter(index);
        }
      }
    }
    buffer_load.CopyOnWrite()->indices = indices;
    return std::move(buffer_load);
  }
};
tvm::transform::Pass ConfigIndexBitwidth() {
  using namespace tir::transform;
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
@@ -81,6 +169,8 @@ tvm::transform::Pass ConfigIndexBitwidth() {
      n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(
          std::move(n->body));
    }
    // Legalize out-of-bound indices to be int64
    n->body = IndexLegalizer::Rewrite(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
...
@@ -60,43 +60,6 @@ private:
  using IRMutatorWithAnalyzer::VisitStmt;
  using IRMutatorWithAnalyzer::VisitStmt_;
  class Int64Promoter : public tir::IndexDataTypeRewriter {
  public:
    using Parent = IndexDataTypeRewriter;
    PrimExpr VisitExpr_(const VarNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), GetRef<Var>(op));
      }
      return GetRef<PrimExpr>(op);
    }
    PrimExpr VisitExpr_(const IntImmNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return IntImm(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }
    PrimExpr VisitExpr_(const CastNode *op) final {
      if (op->dtype.is_int() && op->dtype.bits() < 64) {
        return cast(DataType::Int(64), op->value);
      }
      return GetRef<PrimExpr>(op);
    }
    Stmt VisitStmt_(const BufferStoreNode *op) final {
      // Force indices to be int64
      auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
      return std::move(node);
    }
    PrimExpr VisitExpr_(const BufferLoadNode *op) final {
      auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
      return std::move(node);
    }
  };
  explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}
  Stmt VisitStmt_(const BlockNode *op) final {
@@ -277,29 +240,7 @@ private:
  Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
                                          const Array<PrimExpr> &indices) {
    auto flattened_indices = buffer->ElemOffset(indices);
    Array<PrimExpr> safe_indices;
    for (auto index : flattened_indices) {
      auto int_bound = analyzer_->const_int_bound(index);
      DataType dtype = index->dtype;
      if (dtype.is_int() && dtype.bits() < 64) {
        int64_t max_value = int_bound->max_value;
        int64_t min_value = int_bound->min_value;
        const int64_t type_max = (1LL << (dtype.bits() - 1));
        const int64_t type_min = -(1LL << (dtype.bits() - 1));
        if (max_value >= (type_max - 1) || min_value < type_min) {
          Int64Promoter promoter;
          for (auto &index : flattened_indices) {
            safe_indices.push_back(promoter(index));
          }
        } else {
          safe_indices.push_back(index);
        }
      } else {
        safe_indices.push_back(index);
      }
    }
    return this->IterMapSimplifyWithContext(safe_indices, false);
    return this->IterMapSimplifyWithContext(flattened_indices, false);
  }
  template <typename Node> Node VisitBufferAccess(Node node) {
...