Commit 67b81609 authored by xs-keju, committed by LeiWang1999

[Refactor] Add parallel loop transform pass for condition extraction (#618)



* [Refactor] Add parallel loop transform

* done format check

* pull 3rdparty repo

* Refactor loop variable handling in transformation utilities

- Updated the logic in `loop_parallel_transform_utils.h` to simplify the handling of related loop variables.
- Removed the check that enforced a single related loop variable, replacing it with a return statement when multiple variables are detected, enhancing clarity and maintainability of the transformation process.

* Update loop_parallel_transform_utils.h

* Refactor loop variable handling in transformation utilities

- Enhanced the logic in `loop_parallel_transform_utils.h` to improve clarity and maintainability by simplifying the handling of related loop variables.
- Replaced the previous enforcement of a single related loop variable with an early return when multiple related variables are detected.

* remove disable cache flag, as the commit id is now a key component of the cache key

* lint fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 4fa933c7
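The diffs below wire the new ParallelLoopTransformer into Copy::Lower and into the start of LayoutInference. Conceptually, when a fused parallel loop indexes a local.fragment buffer and the arithmetic analyzer can prove an upper bound on that index that is tighter than the fragment's extent, the pass turns that implicit constraint into an explicit guard inside the loop body. A minimal pure-Python sketch of the before/after behaviour (the extents and the index expression here are invented for illustration, not taken from the patch):

# Illustrative sketch only; the extents and the index expression are invented.
extent = 128                     # extent of the fused parallel loop
frag_shape = 128                 # fragment extent along this dimension

def idx(i):
    return i // 2                # index into the fragment buffer

bound = max(idx(i) for i in range(extent)) + 1   # proven upper bound: 64

# Before the pass: the bound on idx(i) is only implicit in the index math.
before = [0] * frag_shape
for i in range(extent):
    before[idx(i)] += 1

# After the pass (conceptually): the same loop, but with `idx(i) < bound`
# made explicit so later passes (layout inference, vectorization) can rely on it.
after = [0] * frag_shape
for i in range(extent):
    if idx(i) < bound:
        after[idx(i)] += 1

assert before == after and bound < frag_shape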
@@ -73,5 +73,4 @@ jobs:
       run: |
         source tilelang_ci/bin/activate
         cd testing/python
-        export TILELANG_CLEAR_CACHE=1
         python -m pytest
@@ -12,6 +12,7 @@
 #include "../target/utils.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "builtin.h"
@@ -164,11 +165,14 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
   For vectorized_thread_loop;
-  auto par_op = std::make_unique<ParallelOp>(fused_loop);
+  auto par_op = std::make_unique<ParallelOp>(transformed_loop);
   if (is_cpu_target) {
-    vectorized_thread_loop = VectorizeLoop(fused_loop);
+    vectorized_thread_loop = VectorizeLoop(transformed_loop);
   } else {
     std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                       InferLevel::kFree};
...
/*!
 * \file loop_parallel_transform_utils.h
 * \brief Utilities for transforming parallel loops in TL: extracts explicit
 *        bound conditions from parallel loops over fragment buffers.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>
#include <unordered_map>
#include <unordered_set>

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}
  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges of the parallel loop nest
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of the outer loop variable
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }
    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect fragment-buffer access information
    BufferAccessCollector collector;
    collector(op->body);

    PrimExpr condition;
    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";
      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];

        // Collect the variables used in this index expression
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.empty()) {
          continue;
        }

        // Find the loop variables this index depends on. Only indices that
        // depend on a single loop variable are supported; otherwise the loop
        // nest is returned unchanged.
        Array<Var> related_loop_vars;
        for (const auto &loop_var : loop_vars) {
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          if (related_loop_vars.size() > 1) {
            return for_node;
          }
        }

        // If the proven upper bound of the index is tighter than the buffer
        // extent along this dimension, record the constraint as an explicit
        // predicate.
        auto bound = analyzer_->const_int_bound(index);
        int64_t upper_bound = bound->max_value + 1;
        int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
        if (upper_bound < shape) {
          PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
          condition =
              condition.defined() ? And(condition, predicate) : predicate;
        }
      }
    }
    if (condition.defined()) {
      // Wrap the body with the extracted condition and rebuild the loop nest
      body = IfThenElse(condition, body);
      for (int j = static_cast<int>(loop_vars.size()) - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // No condition to extract: leave the parallel loop nest unchanged and do
    // not traverse into the inner loops
    return for_node;
  }
  // Helper class for collecting buffer access information, only counts
  // fragment buffer access
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};
} // namespace tl
} // namespace tvm
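For readers skimming the header above, the per-index decision can be summarised in a few lines. Here is a minimal pure-Python model of it (the function name and its flattened inputs are hypothetical, for illustration only; the real pass works on TIR nodes and the analyzer's const_int_bound):

# Hypothetical, simplified model of the per-index decision made by
# ParallelLoopTransformer (illustration only).
def extract_condition(index_bounds, shape, loop_vars_per_index):
    predicates = []
    for dim, (bound, extent, n_vars) in enumerate(
            zip(index_bounds, shape, loop_vars_per_index)):
        if n_vars == 0:
            continue            # constant index: nothing to guard
        if n_vars > 1:
            return None         # unsupported: leave the loop nest unchanged
        if bound < extent:      # proven bound is tighter than the buffer extent
            predicates.append(f"index_{dim} < {bound}")
    return " and ".join(predicates) or None

# A [128, 128] fragment whose second index provably stays below 96:
print(extract_condition([128, 96], [128, 128], [1, 1]))  # -> index_1 < 96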
@@ -16,6 +16,7 @@
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
+#include "common/loop_parallel_transform_utils.h"
 #include "loop_partition.h"
 #include "loop_vectorize.h"
 #include "runtime/thread_storage_scope.h"
@@ -501,6 +502,7 @@ private:
 tvm::transform::Pass LayoutInference() {
   using namespace tir::transform;
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
     ThreadBindingCollector collector;
     collector(f->body);
     bool has_thread_binding = collector.thread_binding_.size() > 0;
...
import tilelang
import tilelang.testing
import tilelang.language as T
import torch


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M * N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_local = T.alloc_fragment([block_M, block_N], dtype)
            B_local = T.alloc_fragment([block_M * block_N], dtype)
            T.copy(A[by * block_M, bx * block_N], A_local)
            for i, j in T.Parallel(block_M, block_N):
                B_local[i * block_N + j] = A_local[i, j]
            for i in T.Parallel(block_M * block_N):
                B[by * block_M * N + bx * block_N + i // block_N * N + i % block_N] = B_local[i]

    return main


def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
    program = tilelang_composable_copy(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2)


def test_tilelang_copy():
    run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128)
    run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576)
    run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")


if __name__ == "__main__":
    tilelang.testing.main()
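To see how the copy loops come out after the new pass, the generated source can be dumped for one of the shapes exercised above. This sketch reuses tilelang_composable_copy from the test and assumes the compiled kernel exposes get_kernel_source(), as TileLang JIT kernels typically do; treat it as an optional inspection helper rather than part of the test:

# Optional inspection sketch (not part of the test): assumes the compiled
# kernel exposes get_kernel_source() and reuses tilelang_composable_copy
# defined above.
def dump_kernel_source(M=1024, N=576, block_M=32, block_N=576, dtype="float16"):
    program = tilelang_composable_copy(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    print(kernel.get_kernel_source())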