Unverified commit b39aaf5b, authored by Lei Wang, committed by GitHub
Browse files

[Bugfix][WS] Consider loop min extent when computing phase id (#754)

* Update test parameters and remove debug print statement

- Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution.
- Removed a debug print statement from `phase.py` to clean up the code and enhance clarity.

* Refactor loop stack management in warp_specialized_rewriter

- Introduced a new `LoopInfo` struct to encapsulate loop variable details, including `loop_var`, `extent`, and `min`, enhancing clarity and maintainability.
- Updated the `loop_stack_` to utilize `LoopInfo` instead of a pair, improving type safety and readability.
- Adjusted linear index calculations to account for the new structure, ensuring correct behavior in loop transformations.
parent fd199a4a
...@@ -24,6 +24,12 @@ using namespace tir; ...@@ -24,6 +24,12 @@ using namespace tir;
using namespace runtime; using namespace runtime;
using arith::IRVisitorWithAnalyzer; using arith::IRVisitorWithAnalyzer;
// Per-loop metadata recorded while descending through nested `For` nodes in
// the warp-specialized rewriter. Capturing `min` alongside `extent` allows the
// linearized pipeline index to be built from the zero-based iteration count
// (`loop_var - min`), which is required for correct phase/stage computation
// when a loop's minimum extent is nonzero.
struct LoopInfo {
  Var loop_var;     // the loop's iteration variable (For::loop_var)
  PrimExpr extent;  // number of iterations of the loop (For::extent)
  PrimExpr min;     // lower bound of the loop (For::min)
};
enum class Role { kConsumer, kProducer, kBoth }; enum class Role { kConsumer, kProducer, kBoth };
class ProducerBufferDetector : public StmtExprVisitor { class ProducerBufferDetector : public StmtExprVisitor {
...@@ -838,7 +844,7 @@ private: ...@@ -838,7 +844,7 @@ private:
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value); num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
} }
loop_stack_.emplace_back(op->loop_var, op->extent); loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});
Array<Array<Integer>> group_info_array; Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array; Array<Integer> order_info_array;
...@@ -871,15 +877,14 @@ private: ...@@ -871,15 +877,14 @@ private:
num_stages_ = num_stages; num_stages_ = num_stages;
pipeline_info_ = pipeline_info; pipeline_info_ = pipeline_info;
PrimExpr linear_index = loop_stack_[0].first; PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
for (size_t i = 1; i < loop_stack_.size(); ++i) { for (size_t i = 1; i < loop_stack_.size(); ++i) {
linear_index = linear_index = linear_index * loop_stack_[i].extent +
linear_index * loop_stack_[i].second + loop_stack_[i].first; (loop_stack_[i].loop_var - loop_stack_[i].min);
} }
stage_ = FloorMod(linear_index, num_stages); stage_ = FloorMod(linear_index, num_stages);
parity_ = FloorMod( parity_ = FloorMod(
parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);
auto result = FilterByRole(op); auto result = FilterByRole(op);
Stmt grouped_for_node; Stmt grouped_for_node;
...@@ -1137,7 +1142,7 @@ private: ...@@ -1137,7 +1142,7 @@ private:
PrimExpr parity_ = 0; PrimExpr parity_ = 0;
PrimExpr stage_ = 0; PrimExpr stage_ = 0;
int num_stages_ = 1; int num_stages_ = 1;
std::vector<std::pair<Var, PrimExpr>> loop_stack_; std::vector<LoopInfo> loop_stack_;
Var thread_var_; Var thread_var_;
bool mbarrier_only_ = false; bool mbarrier_only_ = false;
PipelineInfo pipeline_info_; PipelineInfo pipeline_info_;
......
...@@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def test_all(): def test_all():
run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MergeSharedMemoryAllocations( mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge)( enable_aggressive_merge=enable_aggressive_merge)(
mod) mod)
print("mod \n", mod)
mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass # Inject PTX async copy must behind the thread sync pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment