You need to sign in or sign up before continuing.
Commit 7d4156df authored by Yu Cheng, committed by GitHub
Browse files

[CI][Test] Add test cases for tilelang transform LowerHopperIntrin (#59)

* [Dev] Add FlashDecoding example

* [CI][Test] Add test cases for tilelang kernel convolution

* [CI][Test] Add test cases for tilelang kernel FlashAttention

* Reduce the number of stages to ensure the shared memory allocation is valid

* Temporarily remove the dim128 case

* lint

* update einops in requirements-dev.txt

* update einops in requirements-test.txt

* remove einops in requirements-dev.txt

* [CI][Test] Add test cases for tilelang transform ClusterPlanning

* [CI][Test] Add test cases for tilelang transform LowerHopperIntrin
parent 5e259239
...@@ -39,7 +39,7 @@ const Op &CreateTMAIm2ColDescriptorOp(); ...@@ -39,7 +39,7 @@ const Op &CreateTMAIm2ColDescriptorOp();
/*! /*!
* \brief Create a list of mbarrier with num_threads * \brief Create a list of mbarrier with num_threads
* *
* GetMBarrier(num_threads0, num_threads1, ...) * CreateListofMBarrierOp(num_threads0, num_threads1, ...)
* *
*/ */
const Op &CreateListofMBarrierOp(); const Op &CreateListofMBarrierOp();
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
from tilelang.utils.target import determine_target
import tilelang.language as T
import tilelang.testing
from tvm import tir
# Target object built once at import time from determine_target("auto");
# _check binds both the input and the expected module to it via BindTarget.
auto_target = tvm.target.Target(determine_target("auto"))
def _check(original, transformed):
    """Assert that LowerHopperIntrin rewrites ``original`` into ``transformed``.

    Both PrimFuncs are wrapped in an IRModule under the global symbol
    ``"main"`` and bound to ``auto_target``. The input module is then run
    through ``LowerHopperIntrin`` followed by ``LowerOpaqueBlock``; the
    expected module only goes through ``LowerOpaqueBlock``. The two resulting
    ``"main"`` functions must be structurally equal.

    Args:
        original: PrimFunc fed to the pass under test.
        transformed: PrimFunc describing the expected result of the pass.

    Raises:
        Whatever ``tvm.ir.assert_structural_equal`` raises on a mismatch.
    """

    def _as_bound_main(prim_func):
        # Shared preparation: name the function "main" and attach the target.
        module = tvm.IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
        return tvm.tir.transform.BindTarget(auto_target)(module)

    # Actual: apply the pass under test, then lower opaque blocks.
    lowered = _as_bound_main(original)
    lowered = tl.transform.LowerHopperIntrin()(lowered)
    lowered = tir.transform.LowerOpaqueBlock()(lowered)

    # Expected: only lower opaque blocks so both sides are in the same form.
    expected = _as_bound_main(transformed)
    expected = tir.transform.LowerOpaqueBlock()(expected)

    # The trailing True is passed positionally, exactly as before
    # (presumably map_free_vars -- confirm against the installed TVM version).
    tvm.ir.assert_structural_equal(lowered["main"], expected["main"], True)
# Test: running LowerHopperIntrin on a kernel that calls
# T.CreateListofMBarrierOp must yield the IR spelled out in `after`;
# the comparison itself is delegated to _check(before, after).
# NOTE(review): indentation is missing in this view, so the nesting of the
# statements below (in particular which calls sit under the T.If guard)
# cannot be confirmed from here -- verify against the real file.
def test_lower_hopper_intrin_barrier():
# Input program: a T.Kernel(8) region with a threadIdx.x launch of extent 128
# that requests four mbarriers, passing 128 for each of them.
@T.prim_func
def before():
with T.Kernel(8):
# The launch variable is not used by the input program, hence the `_` name.
_ = T.launch_thread("threadIdx.x", 128)
T.CreateListofMBarrierOp(128, 128, 128, 128)
# Expected program: a tir.create_barriers call with argument 4, then
# tir.ptx_init_barrier_thread_count calls for barrier indices 0..3 (each with
# 128) guarded by `v_1 == 0`, and a tir.tvm_storage_sync("shared") call.
@T.prim_func
def after():
with T.Kernel(8):
# v_1 is the threadIdx.x launch variable; the guard below compares it to 0.
v_1 = T.launch_thread("threadIdx.x", 128)
T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
with T.If(v_1 == 0), T.Then():
# One thread-count initialization per barrier index.
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(0), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(1), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(2), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(3), 128]))
# NOTE(review): presumably this sync sits outside the `v_1 == 0` guard -- confirm.
T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))
_check(before, after)
if __name__ == "__main__":
    # Script entry point: hand control to the tilelang test runner only.
    # The previous extra direct call to test_lower_hopper_intrin_barrier()
    # after the runner was redundant: it was either unreachable (if the runner
    # exits the process) or a second execution of the same test (if it returns).
    # NOTE(review): presumably tilelang.testing.main() runs pytest on this file
    # and exits with its status -- confirm against tilelang.testing.
    tilelang.testing.main()
...@@ -30,6 +30,7 @@ from .customize import ( ...@@ -30,6 +30,7 @@ from .customize import (
atomic_addx2, # noqa: F401 atomic_addx2, # noqa: F401
dp4a, # noqa: F401 dp4a, # noqa: F401
) )
from .builtin import * # noqa: F401
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True): def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
from tvm import tir
def CreateListofMBarrierOp(*args):
    """Build a ``handle``-typed call to the ``tl.CreateListofMBarrierOp`` op.

    Args:
        *args: Forwarded unchanged as the call's arguments (presumably one
            thread count per barrier to create -- confirm against the op's
            C++ declaration).

    Returns:
        The expression produced by ``tir.call_intrin``.
    """
    intrin = tir.op.Op.get("tl.CreateListofMBarrierOp")
    return tir.call_intrin("handle", intrin, *args)
def GetMBarrierOp(*args):
    """Build a ``handle``-typed call to the ``tl.GetMBarrierOp`` op.

    Args:
        *args: Forwarded unchanged as the call's arguments (presumably the
            index of the barrier to fetch -- confirm against the op's
            C++ declaration).

    Returns:
        The expression produced by ``tir.call_intrin``.
    """
    intrin = tir.op.Op.get("tl.GetMBarrierOp")
    return tir.call_intrin("handle", intrin, *args)
def CreateTMADescriptorOp(*args):
    """Build a ``handle``-typed call to the ``tl.CreateTMADescriptorOp`` op.

    Args:
        *args: Forwarded unchanged as the call's arguments; their meaning is
            defined by the op itself and is not visible from this wrapper.

    Returns:
        The expression produced by ``tir.call_intrin``.
    """
    intrin = tir.op.Op.get("tl.CreateTMADescriptorOp")
    return tir.call_intrin("handle", intrin, *args)
def TMALoadOp(*args):
    """Build a ``handle``-typed call to the ``tl.TMALoadOp`` op.

    Args:
        *args: Forwarded unchanged as the call's arguments; their meaning is
            defined by the op itself and is not visible from this wrapper.

    Returns:
        The expression produced by ``tir.call_intrin``.
    """
    intrin = tir.op.Op.get("tl.TMALoadOp")
    return tir.call_intrin("handle", intrin, *args)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment