"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d6f4774c1c66a7e72951ab60e2241aff14e5d688"
Unverified commit 49d5d80e, authored by Lei Wang and committed by GitHub
Browse files

[Pipeline] Phase out fragment and double buffer info from pipeline pass (#711)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling

- Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes.
- Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management.
- Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body.
- Removed obsolete code and improved overall code clarity and maintainability.

* lint fix

* Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls

- Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves.
- Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations.

* test fix
parent 64bd0651
This diff is collapsed.
...@@ -9,7 +9,6 @@ def _check(original, transformed): ...@@ -9,7 +9,6 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod) mod = tl.transform.Simplify()(mod)
print(mod["main"])
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
...@@ -40,21 +39,29 @@ def test_trival_pipeline(): ...@@ -40,21 +39,29 @@ def test_trival_pipeline():
C[tx, i] = B[tx, 0] + T.float32(1) C[tx, i] = B[tx, 0] + T.float32(1)
@T.prim_func @T.prim_func
def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None: def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
for tx in T.thread_binding(16, thread="threadIdx.x"): for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block(""): with T.block():
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
B = T.alloc_buffer((2, 16, 1), scope="shared") B = T.alloc_buffer((2, 16, 1), scope="shared")
with T.block(""): with T.block():
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(B[0, tx, 0]) T.writes(B[0, tx, 0])
B[0, tx, 0] = A[tx, 0] * T.float32(2.0) B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block(""): with T.block():
T.reads() T.reads(A[tx, 1:1], B[0:2, tx, 0])
T.writes() T.writes(B[1:1, tx, 0], C[tx, 0:0])
T.evaluate(0) for i in range(0):
with T.block(""): with T.block():
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
with T.block():
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
with T.block():
T.reads(B[0, tx, 0]) T.reads(B[0, tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1.0) C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
......
...@@ -79,7 +79,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -79,7 +79,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LegalizeVectorizedLoop()(mod) mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses # Add safety checks for memory accesses
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Align dynamic shared memory allocations
# Simplify again to clean up any duplicated conditions # Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks # that may have been introduced by safety checks
# use an enhanced pass to simplify the dynamic symbolics # use an enhanced pass to simplify the dynamic symbolics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment