Unverified commit 73bf8346, authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Rebase pipeline injector from upstream tvm (#687)

* [Enhancement] Introduce software pipeline rewriter and refactor buffer access handling

- Added a new `PipelineOpaqueAccessRewriter` class to manage opaque buffer accesses in the software pipeline.
- Refactored the `PipelineBodyRewriter` to utilize the new rewriter for improved buffer access handling.
- Enhanced the `PipelineRewriter` to support additional fragment information and streamline pipeline construction.
- Updated tests to reflect changes in buffer management and access patterns, ensuring compatibility with the new structure.
- Removed obsolete code related to previous buffer access methods for clarity and maintainability.

* test fix
parent b45e9c45
This diff is collapsed.
import torch import torch
import torch.backends
import tilelang.testing import tilelang.testing
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.language as T import tilelang.language as T
......
...@@ -9,6 +9,7 @@ def _check(original, transformed): ...@@ -9,6 +9,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod) mod = tl.transform.Simplify()(mod)
print(mod["main"])
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
...@@ -41,30 +42,22 @@ def test_trival_pipeline(): ...@@ -41,30 +42,22 @@ def test_trival_pipeline():
@T.prim_func @T.prim_func
def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None: def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None:
for tx in T.thread_binding(16, thread="threadIdx.x"): for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block(): with T.block(""):
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") B = T.alloc_buffer((2, 16, 1), scope="shared")
with T.block(): with T.block(""):
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(B[0, tx, 0]) T.writes(B[0, tx, 0])
B[0, tx, 0] = A[tx, 0] * T.float32(2) B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block(): with T.block(""):
T.reads(A[tx, 1:1], B[0:2, tx, 0]) T.reads()
T.writes(B[1:1, tx, 0], C[tx, 0:0]) T.writes()
for i in range(0): T.evaluate(0)
with T.block(""): with T.block(""):
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block(""):
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1)
with T.block():
T.reads(B[0, tx, 0]) T.reads(B[0, tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1) C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
_check(before, expected) _check(before, expected)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment