"examples/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "6972aed7289e2a817a1221c9dcd0771fcf14084e"
Unverified Commit 0592834f authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Feat] Add A Pass to Handle Negative Index (#1192)

parent 777881e1
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRVisitorWithAnalyzer;
/*!
 * \brief Sign classification of a single buffer-load index.
 *
 * - kNonNegative: the analyzer proved `index >= 0`.
 * - kNegative:    the analyzer proved `index < 0`.
 * - kUnknown:     neither of the two facts above could be proven.
 */
enum class IndexSignState { kNonNegative, kNegative, kUnknown };
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
explicit NegativeIndexAnalyzer(
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
: result_(result) {}
void VisitExpr_(const BufferLoadNode *op) final {
auto load = tvm::ffi::GetRef<BufferLoad>(op);
std::vector<IndexSignState> states;
states.reserve(op->indices.size());
bool needs_record = false;
for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ").";
}
if (needs_record) {
(*result_)[op] = std::move(states);
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result_;
};
/*!
 * \brief Mutator that rewrites provably negative BufferLoad indices.
 *
 * Using the per-load sign states computed beforehand, every index marked
 * kNegative is replaced by `simplify(extent + index)`, where `extent` is the
 * buffer shape along the same axis. Indices marked kNonNegative or kUnknown
 * are left untouched.
 */
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
  /*!
   * \brief Rewrite the body of \p func according to \p states.
   * \param func   The function to transform; returned unchanged when it has
   *               no defined body.
   * \param states Sign states keyed by the original BufferLoadNode pointers
   *               found in the body of \p func.
   * \return The function with its body rewritten.
   */
  static PrimFunc
  Apply(PrimFunc func,
        const std::unordered_map<const BufferLoadNode *,
                                 std::vector<IndexSignState>> &states) {
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    if (!func->body.defined()) {
      return func;
    }
    PrimFuncNode *func_node = func.CopyOnWrite();
    func_node->body = rewriter.VisitStmt(func_node->body);
    return func;
  }

private:
  NegativeIndexRewriter(
      arith::Analyzer *analyzer,
      const std::unordered_map<const BufferLoadNode *,
                               std::vector<IndexSignState>> &states)
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

  /*!
   * \brief Legalize the indices of a single buffer load.
   * \param op The original (pre-mutation) load; its address is the lookup key
   *           into the sign-state table.
   * \return The load with negative indices shifted by the axis extent, or the
   *         default-mutated load when nothing needs to change.
   */
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
    auto it = states_.find(op);
    if (it == states_.end()) {
      return load;
    }
    auto indices = load->indices;
    bool changed = false;
    const auto &state_vector = it->second;
    ICHECK_EQ(state_vector.size(), indices.size())
        << "State vector size mismatch for buffer load " << load->buffer->name;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (state_vector[i] != IndexSignState::kNegative) {
        continue;
      }
      PrimExpr extent = load->buffer->shape[i];
      indices.Set(i, analyzer_->Simplify(extent + indices[i]));
      changed = true;
    }
    if (!changed) {
      return load;
    }
    // Update only the indices on a private copy of the node. Re-creating the
    // load from (buffer, indices) alone would drop the remaining fields of the
    // original node, such as its predicate and source span.
    load.CopyOnWrite()->indices = indices;
    return load;
  }

  /*! \brief Sign states produced by the analysis step; owned by the caller. */
  const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      &states_;
};
/*!
 * \brief Run the negative-index legalization on a single PrimFunc.
 *
 * The body is first analyzed to classify buffer-load indices; the rewriter is
 * invoked only if the analysis recorded at least one load.
 *
 * \param func The function to process.
 * \return The input function, rewritten when the analysis recorded any load,
 *         otherwise returned as-is.
 */
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (func->body.defined()) {
    std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
        sign_table;
    NegativeIndexAnalyzer collector(&sign_table);
    collector(func->body);
    if (!sign_table.empty()) {
      return NegativeIndexRewriter::Apply(std::move(func), sign_table);
    }
  }
  return func;
}
/*!
 * \brief Build the PrimFunc-level pass wrapping LegalizeNegativeIndex.
 * \return A pass named "tl.LegalizeNegativeIndex" created at opt level 0 with
 *         no required passes.
 */
tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  return CreatePrimFuncPass(
      [](PrimFunc prim_func, const IRModule &, PassContext) {
        return LegalizeNegativeIndex(std::move(prim_func));
      },
      0, "tl.LegalizeNegativeIndex", {});
}
// Expose the pass factory through the FFI registry under the global name
// "tl.transform.LegalizeNegativeIndex".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                                        LegalizeNegativeIndexPass);
}
} // namespace tl
} // namespace tvm
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T
# Input fixture: reads A with the constant negative index -1.
@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(-1)]
# Expected result for negative_index_before: -1 on an extent-16 axis becomes 15.
@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(15)]
# Input fixture: the index -i - 1 is negative for every i in the loop range [0, 4).
@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[-i - 1]
# Expected result for negative_index_loop_before: 16 + (-i - 1) written as 15 - i.
@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[15 - i]
# Input fixture: `shift` is an unconstrained int32 parameter, so the sign of
# shift + i is not fixed by anything in this function.
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
                                   B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(16):
        B[i] = A[shift + i]
def test_legalize_negative_index_scalar():
    """A constant negative index is rewritten to its non-negative equivalent."""
    before = tvm.IRModule({"main": negative_index_before})
    legalize = tl.transform.LegalizeNegativeIndex()
    after = legalize(before)
    actual_body = after["main"].body
    tvm.ir.assert_structural_equal(actual_body, negative_index_expected.body)
def test_legalize_negative_index_affine_expr():
    """An affine index that is negative over the whole loop range is rewritten."""
    before = tvm.IRModule({"main": negative_index_loop_before})
    legalize = tl.transform.LegalizeNegativeIndex()
    after = legalize(before)
    actual_body = after["main"].body
    tvm.ir.assert_structural_equal(actual_body, negative_index_loop_expected.body)
def test_legalize_negative_index_symbolic_passthrough():
    """An index of unknown sign must leave the function body unchanged."""
    before = tvm.IRModule({"main": negative_index_symbolic_before})
    legalize = tl.transform.LegalizeNegativeIndex()
    after = legalize(before)
    actual_body = after["main"].body
    tvm.ir.assert_structural_equal(actual_body, negative_index_symbolic_before.body)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LetInline()(mod) mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store # Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Normalize negative indices to canonical non-negative form
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
# Inject assumes to speedup tvm prover # Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod) mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions # Simplify the IR expressions
......
...@@ -80,6 +80,17 @@ def FrontendLegalize(): ...@@ -80,6 +80,17 @@ def FrontendLegalize():
return _ffi_api.FrontendLegalize() # type: ignore return _ffi_api.FrontendLegalize() # type: ignore
def LegalizeNegativeIndex():
    """Legalize negative indices in buffer loads.

    Buffer-load indices that can be proven negative are rewritten to
    ``extent + index`` along the corresponding buffer axis. Indices whose
    sign cannot be proven are left unchanged.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LegalizeNegativeIndex()  # type: ignore
def InjectAssumes(): def InjectAssumes():
"""Inject Assumes """Inject Assumes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment