"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "8a5eb569704bfea64478c29adcfe3a09e3c2b12c"
Unverified Commit fac04006 authored by ConvolutedDog's avatar ConvolutedDog Committed by GitHub
Browse files

[Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339)

This commit enhances the LegalizeNegativeIndex transformation pass to handle
both buffer load and store operations with negative indices and adds some
test cases.
parent f810f976
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <tvm/ffi/cast.h> #include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h> #include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h> #include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h> #include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h> #include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h> #include <tvm/ffi/string.h>
......
/*! /*!
* \file legalize_negative_index.cc * \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions. * \brief Legalize negative indices in buffer load/store expressions.
*/ */
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <unordered_map> #include <unordered_map>
#include <variant>
#include <vector> #include <vector>
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
...@@ -23,47 +24,42 @@ using arith::IRVisitorWithAnalyzer; ...@@ -23,47 +24,42 @@ using arith::IRVisitorWithAnalyzer;
enum class IndexSignState { kNonNegative, kNegative, kUnknown }; enum class IndexSignState { kNonNegative, kNegative, kUnknown };
using BufferAccessVariant =
std::variant<const BufferLoadNode *, const BufferStoreNode *>;
using LoadStore2StateMap =
std::unordered_map<BufferAccessVariant, std::vector<IndexSignState>>;
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public: public:
explicit NegativeIndexAnalyzer( explicit NegativeIndexAnalyzer(LoadStore2StateMap *result)
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
: result_(result) {} : result_(result) {}
void VisitExpr_(const BufferLoadNode *op) final { private:
auto load = tvm::ffi::GetRef<BufferLoad>(op); std::vector<IndexSignState> ProcessIdx(const ffi::Array<PrimExpr> &indices,
ffi::String buffer_name) {
std::vector<IndexSignState> states; std::vector<IndexSignState> states;
states.reserve(op->indices.size()); states.reserve(indices.size());
bool needs_record = false;
for (size_t i = 0; i < op->indices.size(); ++i) { for (size_t i = 0; i < indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]); PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown;
// Handle scalar indices with the standard analyzer // Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) { if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) { if (analyzer_.CanProve(simplified >= 0))
states.push_back(IndexSignState::kNonNegative); state = IndexSignState::kNonNegative;
continue; else if (analyzer_.CanProve(simplified < 0))
} state = IndexSignState::kNegative;
if (analyzer_.CanProve(simplified < 0)) { else
states.push_back(IndexSignState::kNegative); DLOG(WARNING)
needs_record = true; << "LegalizeNegativeIndex: cannot prove non-negative index "
continue; << simplified << " for buffer " << buffer_name << " (axis " << i
} << ", index " + indices[i]->Script() + ").";
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue;
} }
// Vector indices: try to reason about non-negativity/negativity // Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value, // Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes). // lanes).
IndexSignState vec_state = IndexSignState::kUnknown; else if (const auto *ramp = simplified.as<RampNode>()) {
if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes // Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1) // lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1) // upper_bound = base_max + max(0, stride_max) * (lanes - 1)
...@@ -85,118 +81,129 @@ public: ...@@ -85,118 +81,129 @@ public:
if (s_max > 0) if (s_max > 0)
upper += s_max * (lanes - 1); upper += s_max * (lanes - 1);
if (lower >= 0) { if (lower >= 0)
vec_state = IndexSignState::kNonNegative; state = IndexSignState::kNonNegative;
} else if (upper < 0) { else if (upper < 0)
vec_state = IndexSignState::kNegative; state = IndexSignState::kNegative;
} else { else
vec_state = IndexSignState::kUnknown; DLOG(WARNING)
} << "LegalizeNegativeIndex: cannot prove non-negative index "
} else if (const auto *bc = simplified.as<BroadcastNode>()) { << simplified << " for buffer " << buffer_name << " (axis " << i
auto v = analyzer_.Simplify(bc->value); << ", index " + indices[i]->Script() + ").";
if (analyzer_.CanProve(v >= 0)) { } else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
vec_state = IndexSignState::kNonNegative; auto v = analyzer_.Simplify(broadcast->value);
} else if (analyzer_.CanProve(v < 0)) { if (analyzer_.CanProve(v >= 0))
vec_state = IndexSignState::kNegative; state = IndexSignState::kNonNegative;
} else { else if (analyzer_.CanProve(v < 0))
state = IndexSignState::kNegative;
else {
// Try const bound if proof unavailable // Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v); auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) { if (vb->min_value >= 0)
vec_state = IndexSignState::kNonNegative; state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) { else if (vb->max_value < 0)
vec_state = IndexSignState::kNegative; state = IndexSignState::kNegative;
} else { else
vec_state = IndexSignState::kUnknown; DLOG(WARNING)
} << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
} }
} }
states.push_back(state);
}
if (vec_state == IndexSignState::kNonNegative) { return std::move(states);
states.push_back(IndexSignState::kNonNegative); }
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown); bool NeedRecord(const std::vector<IndexSignState> &states) {
needs_record = true; return std::any_of(states.begin(), states.end(),
DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " [](const IndexSignState &state) {
<< simplified << " for buffer " << load->buffer->name return state == IndexSignState::kUnknown ||
<< " (axis " << i << ")."; state == IndexSignState::kNegative;
} });
}
void VisitExpr_(const BufferLoadNode *op) final {
std::vector<IndexSignState> states =
ProcessIdx(op->indices, op->buffer->name);
if (needs_record) { if (NeedRecord(states))
(*result_)[op] = std::move(states); (*result_)[op] = std::move(states);
}
IRVisitorWithAnalyzer::VisitExpr_(op); IRVisitorWithAnalyzer::VisitExpr_(op);
} }
void VisitStmt_(const BufferStoreNode *op) final {
std::vector<IndexSignState> states =
ProcessIdx(op->indices, op->buffer->name);
if (NeedRecord(states))
(*result_)[op] = std::move(states);
IRVisitorWithAnalyzer::VisitStmt_(op);
}
private: private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> LoadStore2StateMap *result_;
*result_;
}; };
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer { class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public: public:
static PrimFunc static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) {
Apply(PrimFunc func,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
NegativeIndexRewriter rewriter(&analyzer, states); NegativeIndexRewriter rewriter(&analyzer, states);
if (!func->body.defined()) {
return func;
}
PrimFuncNode *func_node = func.CopyOnWrite(); PrimFuncNode *func_node = func.CopyOnWrite();
func_node->body = rewriter.VisitStmt(func_node->body); func_node->body = rewriter.VisitStmt(func_node->body);
return func; return func;
} }
private: private:
NegativeIndexRewriter( NegativeIndexRewriter(arith::Analyzer *analyzer,
arith::Analyzer *analyzer, const LoadStore2StateMap &states)
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states)
: arith::IRMutatorWithAnalyzer(analyzer), states_(states) {} : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}
ffi::Array<PrimExpr> UpdateIdx(const ffi::Array<PrimExpr> &indices,
const ffi::Array<PrimExpr> &buffer_shape,
const std::vector<IndexSignState> &state_vec) {
ICHECK_EQ(state_vec.size(), indices.size())
<< "State vector size mismatch for buffer load/store indices ("
<< indices << ")";
ffi::Array<PrimExpr> new_indices = indices;
for (size_t i = 0; i < indices.size(); ++i) {
if (state_vec[i] != IndexSignState::kNegative)
continue;
new_indices.Set(i, analyzer_->Simplify(buffer_shape[i] + indices[i]));
}
return new_indices;
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load = BufferLoad load =
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
auto it = states_.find(op); auto it = states_.find(op);
if (it == states_.end()) { if (it == states_.end())
return load; return load;
}
auto indices = load->indices; auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second);
bool changed = false; return BufferLoad(load->buffer, indices, load->predicate);
}
const auto &state_vector = it->second;
ICHECK_EQ(state_vector.size(), indices.size())
<< "State vector size mismatch for buffer load " << load->buffer->name;
for (size_t i = 0; i < indices.size(); ++i) { Stmt VisitStmt_(const BufferStoreNode *op) final {
if (state_vector[i] != IndexSignState::kNegative) { BufferStore store =
continue; Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
}
PrimExpr extent = load->buffer->shape[i];
indices.Set(i, analyzer_->Simplify(extent + indices[i]));
changed = true;
}
if (!changed) { auto it = states_.find(op);
return load; if (it == states_.end())
} return store;
return BufferLoad(load->buffer, indices); auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second);
return BufferStore(store->buffer, store->value, indices, store->predicate);
} }
const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> private:
&states_; const LoadStore2StateMap &states_;
}; };
PrimFunc LegalizeNegativeIndex(PrimFunc func) { PrimFunc LegalizeNegativeIndex(PrimFunc func) {
...@@ -204,8 +211,7 @@ PrimFunc LegalizeNegativeIndex(PrimFunc func) { ...@@ -204,8 +211,7 @@ PrimFunc LegalizeNegativeIndex(PrimFunc func) {
return func; return func;
} }
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>> LoadStore2StateMap states;
states;
NegativeIndexAnalyzer analyzer(&states); NegativeIndexAnalyzer analyzer(&states);
analyzer(func->body); analyzer(func->body);
if (states.empty()) { if (states.empty()) {
......
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
def _check(original, expected):
"""Helper function to verify structural equality after transformations"""
func = original
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.LegalizeNegativeIndex()(mod)
expected = tvm.IRModule.from_expr(expected.with_attr("global_symbol", "main"))
tvm.ir.assert_structural_equal(mod["main"], expected["main"], True)
def test_buffer_load_negative_index_legalized():
"""
Test that negative indices are legalized by adding buffer extent.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
value = A[-1]
B = T.alloc_buffer((1,), "float32")
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
value = A[1023] # A[-1] becomes A[1023]
B = T.alloc_buffer((1,), "float32")
B[0] = value
_check(before, after)
def test_buffer_load_mixed_negative_positive_indices():
"""
Test mixed negative and positive indices - only negative ones are legalized.
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
value = A[-1, 10]
B = T.alloc_buffer((1,), "float32")
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
value = A[1023, 10] # A[-1, 10] becomes A[1023, 10]
B = T.alloc_buffer((1,), "float32")
B[0] = value
_check(before, after)
def test_buffer_load_multiple_negative_indices():
"""
Test multiple negative indices in different dimensions.
"""
@T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")):
value = A[-1, -2, -3]
B = T.alloc_buffer((1,), "float32")
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")):
value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253
B = T.alloc_buffer((1,), "float32")
B[0] = value
_check(before, after)
def test_buffer_load_negative_index_in_expression():
"""
Test negative index as part of a larger expression.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
B = T.alloc_buffer((1024,), "float32")
for i in T.serial(1, 1024):
value = A[-i]
B[-i] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
B = T.alloc_buffer((1024,), "float32")
for i in T.serial(1, 1024):
value = A[1024 - i]
B[1024 - i] = value
_check(before, after)
def test_buffer_load_non_negative_index_unchanged():
"""
Test that non-negative indices remain unchanged.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
value = A[0]
B = T.alloc_buffer((1,), "float32")
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
# No changes expected for non-negative indices
value = A[0]
B = T.alloc_buffer((1,), "float32")
B[0] = value
_check(before, after)
def test_buffer_load_unknown_sign_index_warning():
"""
Test that indices with unknown sign trigger warnings but are processed.
This test mainly checks that the pass doesn't crash on unknown signs.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
i = T.Var("i", "int32")
value = A[i]
B = T.alloc_buffer((1,), "float32")
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
i = T.Var("i", "int32")
# Unknown sign indices should remain unchanged
value = A[i]
B = T.alloc_buffer((1,), "float32")
B[0] = value
_check(before, after)
def test_buffer_load_vector_index_negative_broadcast():
"""
Test negative indices in vectorized operations (broadcast case).
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
vec = T.Broadcast(-1, 4)
value = A[vec]
B = T.alloc_buffer((4,), "float32")
B[T.Ramp(0, 1, 4)] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
# vec is unused and can be delimed by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841
value = A[T.Broadcast(1023, 4)]
B = T.alloc_buffer((4,), "float32")
B[T.Ramp(0, 1, 4)] = value
_check(before, after)
def test_buffer_load_vector_index_negative_ramp():
"""
Test negative indices in vectorized operations (ramp case).
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
value = A[vec]
B = T.alloc_buffer((4,), "float32")
B[T.Ramp(0, 1, 4)] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
# vec is unused and can be delimed by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841
value = A[T.Ramp(1020, 1, 4)]
B = T.alloc_buffer((4,), "float32")
B[T.Ramp(0, 1, 4)] = value
_check(before, after)
def test_buffer_load_nested_buffer_loads():
"""
Test legalization with nested buffer load expressions.
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
inner_val = A[-1, 10]
outer_val = A[inner_val.astype("int32"), -2]
B = T.alloc_buffer((1,), "float32")
B[0] = outer_val
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
inner_val = A[1023, 10]
outer_val = A[inner_val.astype("int32"), 510]
B = T.alloc_buffer((1,), "float32")
B[0] = outer_val
_check(before, after)
def test_buffer_store_negative_index():
"""
Test negative indices in buffer store operations are legalized.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
A[-1] = 42.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
A[1023] = 42.0
_check(before, after)
def test_buffer_store_mixed_negative_positive_indices():
"""
Test mixed negative and positive indices in buffer store.
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
A[-1, 10] = 42.0
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
A[1023, 10] = 42.0
_check(before, after)
def test_buffer_store_multiple_negative_indices():
"""
Test multiple negative indices in different dimensions for buffer store.
"""
@T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")):
A[-1, -2, -3] = 42.0
@T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")):
A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253
_check(before, after)
def test_buffer_store_negative_index_in_expression():
"""
Test negative index as part of a larger expression in buffer store.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
for i in T.serial(1, 1024):
A[-i] = i * 2.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
for i in T.serial(1, 1024):
A[1024 - i] = i * 2.0
_check(before, after)
def test_buffer_store_vector_index_negative_broadcast():
"""
Test negative indices in vectorized store operations (broadcast case).
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
vec = T.Broadcast(-1, 4)
values = T.Broadcast(42.0, 4)
A[vec] = values
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
# vec is unused and can be delimed by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841
values = T.Broadcast(42.0, 4)
A[T.Broadcast(1023, 4)] = values
_check(before, after)
def test_buffer_store_vector_index_negative_ramp():
"""
Test negative indices in vectorized store operations (ramp case).
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0]
A[vec] = values
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
# vec is unused and can be delimed by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841
values = T.Ramp(0.0, 1.0, 4)
A[T.Ramp(1020, 1, 4)] = values
_check(before, after)
def test_buffer_store_nested_in_condition():
"""
Test negative index buffer store within conditional statements.
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32"), flag: T.int32):
if flag > 0:
A[-1] = 42.0
else:
A[-2] = 24.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32"), flag: T.int32):
if flag > 0:
A[1023] = 42.0
else:
A[1022] = 24.0
_check(before, after)
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment