"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "e4d17fe89c4758f21b450965d15ec08ab4698fac"
Unverified commit bb8b3cd7, authored by Tong WU, committed by GitHub
Browse files

[Enhancement] Update async intrinsic handling in inject_fence_proxy (#1068)



* [Enhancement] Update async intrinsic handling in inject_fence_proxy

* Added support for wgmma async intrinsics in IsAsyncIntrinsic function.
* Changed handling of unknown externs to treat them as Generic instead of Async, improving accuracy in proxy kind determination.

* test fix

* Update testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent f8d3e73e
......@@ -94,6 +94,11 @@ bool IsAsyncIntrinsic(const CallNode *call) {
return true;
}
// wgmma async intrinsics
if (call->op.same_as(tl_gemm()) || call->op.same_as(tl_gemm_sp())) {
return true;
}
return false;
}
......@@ -208,8 +213,10 @@ private:
} else if (IsKnownGeneric(call)) {
kind = ProxyKind::kGeneric;
} else {
// Treat unknown externs as async to avoid missing required fences.
kind = ProxyKind::kAsync;
// We can now treat extern as Generic, since gemm and gemm_sp are never
// represented as call_extern nodes. They are call_intrin nodes and will
// be handled by IsAsyncIntrinsic above.
kind = ProxyKind::kGeneric;
}
}
......
......@@ -31,7 +31,8 @@ def test_lower_fence_proxy():
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
......@@ -45,7 +46,8 @@ def test_lower_fence_proxy():
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.fence_proxy_async()
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
......@@ -169,7 +171,6 @@ def test_wgmma_marked_async():
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
order = []
def visit(node):
......@@ -185,43 +186,5 @@ def test_wgmma_marked_async():
assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")
def test_wgmma_after_descriptor():
@T.prim_func
def before():
with T.Kernel(1):
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
C_local = T.decl_buffer((32,), "float16", scope="local")
T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32)
T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32)
T.warpgroup_arrive()
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
T.int32(0), T.bool(True), 1, 1)
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
fence_count = 0
order = []
def visit(node):
nonlocal fence_count
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
name = getattr(call.op, "name", "")
order.append(name)
if name == "tl.fence_proxy_async":
fence_count += 1
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
assert fence_count >= 1
assert "tl.warpgroup_arrive" in order
assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment