Unverified Commit 73bf8346 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Rebase pipeline injector from upstream tvm (#687)

* [Enhancement] Introduce software pipeline rewriter and refactor buffer access handling

- Added a new `PipelineOpaqueAccessRewriter` class to manage opaque buffer accesses in the software pipeline.
- Refactored the `PipelineBodyRewriter` to utilize the new rewriter for improved buffer access handling.
- Enhanced the `PipelineRewriter` to support additional fragment information and streamline pipeline construction.
- Updated tests to reflect changes in buffer management and access patterns, ensuring compatibility with the new structure.
- Removed obsolete code related to previous buffer access methods for clarity and maintainability.

* test fix
parent b45e9c45
This diff is collapsed.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
......
......@@ -9,6 +9,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod)
print(mod["main"])
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)
......@@ -41,30 +42,22 @@ def test_trival_pipeline():
@T.prim_func
def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None:
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
with T.block(""):
T.reads(A[tx, 0])
T.writes(C[tx, 0])
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
with T.block():
B = T.alloc_buffer((2, 16, 1), scope="shared")
with T.block(""):
T.reads(A[tx, 0])
T.writes(B[0, tx, 0])
B[0, tx, 0] = A[tx, 0] * T.float32(2)
with T.block():
T.reads(A[tx, 1:1], B[0:2, tx, 0])
T.writes(B[1:1, tx, 0], C[tx, 0:0])
for i in range(0):
B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block(""):
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2)
T.reads()
T.writes()
T.evaluate(0)
with T.block(""):
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1)
with T.block():
T.reads(B[0, tx, 0])
T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1)
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
_check(before, expected)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment