"...reference/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bbf54a8835811f96bd1e4dc4c2669f94be0bf264"
Commit afa74f4e authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix a bug for simplifier (#425)

* Update submodule 'tvm' to latest commit f4a8f9b

* lint fix
parent 2fff0eec
- Subproject commit b2945254932cffa89922ec7f6e868d726aed0f6a
+ Subproject commit b16c9f298bc37fa502ffdb2ea809c2793e2a0bd6
import torch

import tilelang
import tilelang.language as T
import tilelang.testing


def tilelang_copy(M, N, block_M, block_N, dtype="float16"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: one block per (block_M, block_N) tile
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Copy one tile of A into B element-wise
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]

    return main


def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
    program = tilelang_copy(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],  # B is allocated by the kernel and returned as the output
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True
        })
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)


def test_tilelang_copy():
    run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128)
    run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576)
    run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")


if __name__ == "__main__":
    tilelang.testing.main()
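The in-code comment above notes that the factory can instead be decorated with @tilelang.jit to get a torch-callable kernel directly, without an explicit tilelang.compile call. A minimal sketch of that variant, assuming tilelang.jit accepts the same out_idx argument and compiles the returned prim_func on first use:

# Hedged sketch, not part of this commit: assumes @tilelang.jit(out_idx=...)
# wraps the kernel factory so calling it yields a torch-callable kernel.
import torch

import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[1])  # assumption: same out_idx semantics as tilelang.compile
def tilelang_copy_jit(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]

    return main


# Calling the decorated factory returns a kernel that takes A and
# returns the freshly allocated output tensor B.
kernel = tilelang_copy_jit(1024, 1024, 128, 128)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = kernel(a)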