Unverified Commit b39aaf5b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix][WS] Consider loop min extent when computing phase id (#754)

* Update test parameters and remove debug print statement

- Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution.
- Removed a debug print statement from `phase.py` to clean up the code and enhance clarity.

* Refactor loop stack management in warp_specialized_rewriter

- Introduced a new `LoopInfo` struct to encapsulate loop variable details, including `loop_var`, `extent`, and `min`, enhancing clarity and maintainability.
- Updated the `loop_stack_` to utilize `LoopInfo` instead of a pair, improving type safety and readability.
- Adjusted linear index calculations to account for the new structure, ensuring correct behavior in loop transformations.
parent fd199a4a
......@@ -24,6 +24,12 @@ using namespace tir;
using namespace runtime;
using arith::IRVisitorWithAnalyzer;
struct LoopInfo {
Var loop_var;
PrimExpr extent;
PrimExpr min;
};
enum class Role { kConsumer, kProducer, kBoth };
class ProducerBufferDetector : public StmtExprVisitor {
......@@ -838,7 +844,7 @@ private:
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
}
loop_stack_.emplace_back(op->loop_var, op->extent);
loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});
Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array;
......@@ -871,15 +877,14 @@ private:
num_stages_ = num_stages;
pipeline_info_ = pipeline_info;
PrimExpr linear_index = loop_stack_[0].first;
PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
for (size_t i = 1; i < loop_stack_.size(); ++i) {
linear_index =
linear_index * loop_stack_[i].second + loop_stack_[i].first;
linear_index = linear_index * loop_stack_[i].extent +
(loop_stack_[i].loop_var - loop_stack_[i].min);
}
stage_ = FloorMod(linear_index, num_stages);
parity_ = FloorMod(
parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);
auto result = FilterByRole(op);
Stmt grouped_for_node;
......@@ -1137,7 +1142,7 @@ private:
PrimExpr parity_ = 0;
PrimExpr stage_ = 0;
int num_stages_ = 1;
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
std::vector<LoopInfo> loop_stack_;
Var thread_var_;
bool mbarrier_only_ = false;
PipelineInfo pipeline_info_;
......
......@@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def test_all():
run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32)
if __name__ == "__main__":
......
......@@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge)(
mod)
print("mod \n", mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment