"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "5bf0475afdf33e823ed8c0247b3a8711326601d3"
Commit f944b79e authored by Cunxiao Ni, committed by GitHub

[CI][Test] Add test cases for element_add (#47)

* [CI][Test] Add test cases for element_add

* [Doc] fix typo

* Parallelization

* format

* remove useless condition

* format
parent bedab1a0
@@ -3,25 +3,26 @@
 import tilelang
 import tilelang.language as T
 # `make_mma_swizzle_layout` is a python defined layout function
-# specifically designed for for MMA operations
+# specifically designed for MMA operations
 # which ensures the consistency with the nvidia CUTLASS Library.
 # to avoid bank conflicts and maximize the performance.
 from tilelang.intrinsics import (
-    make_mma_swizzle_layout as make_swizzle_layout,)
+    make_mma_swizzle_layout as make_swizzle_layout,)  # noqa: F401
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     # add decorator @tilelang.jit if you want to return a torch function
     @T.prim_func
     def main(
             A: T.Buffer((M, K), dtype),
             B: T.Buffer((K, N), dtype),
             C: T.Buffer((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             # Apply layout optimizations or define your own layout (Optional)
             # If not specified, we will deduce the layout automatically
@@ -71,7 +72,6 @@ import torch
 a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
 b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
-# Run the kernel through the Profiler
 c = jit_kernel(a, b)
@@ -86,7 +86,7 @@ print("Kernel output matches PyTorch reference.")
 cuda_source = jit_kernel.get_kernel_source()
 print("Generated CUDA kernel:\n", cuda_source)
-# 5.Pofile latency with kernel
+# 5.Profile latency with kernel
 profiler = jit_kernel.get_profiler()
 latency = profiler.do_bench()
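For context on the last hunk: the README snippet verifies the kernel output against a PyTorch reference before profiling. Below is a minimal sketch of such a check using only standard PyTorch APIs; `a`, `b`, and `c` are the tensors from the snippet above, the construction of `jit_kernel` is not part of this diff, and the tolerances are illustrative rather than the README's actual values.

import torch

# `c = jit_kernel(a, b)` was computed above; compare it against a plain PyTorch matmul.
ref_c = a @ b  # float16 reference matmul on the same CUDA device
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

The new test file added by this commit follows.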
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch


def elementwise_add(
    M,
    N,
    block_M,
    block_N,
    in_dtype,
    out_dtype,
    threads,
):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((M, N), in_dtype),
            B: T.Buffer((M, N), in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        # Launch one thread block per (block_M, block_N) tile of the output.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            start_x = bx * block_N
            start_y = by * block_M
            # T.Parallel distributes the tile's elements across the block's threads.
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                y = start_y + local_y
                x = start_x + local_x
                C[y, x] = A[y, x] + B[y, x]

    return main
def run_elementwise_add(
    M,
    N,
    in_dtype,
    out_dtype,
    block_M,
    block_N,
    num_threads=128,
):
    program = elementwise_add(
        M,
        N,
        block_M,
        block_N,
        in_dtype,
        out_dtype,
        num_threads,
    )
    # Lower the TIR program and wrap it in a Profiler that supplies integer-valued inputs.
    mod, params = tl.lower(program)
    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    # PyTorch reference: elementwise add, cast to the requested output dtype.
    def ref_program(A, B):
        C = torch.add(A, B)
        C = C.to(getattr(torch, out_dtype))
        return C

    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_elementwise_add_f32():
    run_elementwise_add(
        512,
        1024,
        "float32",
        "float32",
        128,
        256,
    )


def test_elementwise_add_f16():
    run_elementwise_add(
        512,
        1024,
        "float16",
        "float16",
        128,
        256,
    )


def test_elementwise_add_i32():
    run_elementwise_add(
        512,
        1024,
        "int32",
        "int32",
        128,
        256,
    )


def test_elementwise_add_f32f16():
    run_elementwise_add(
        512,
        1024,
        "float32",
        "float16",
        128,
        256,
    )


if __name__ == "__main__":
    tilelang.testing.main()
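As a usage note, these cases run through `tilelang.testing.main()` as above (the file is also collectable by pytest, since each case is a plain `test_*` function). A single configuration can likewise be exercised directly; the following is a hypothetical one-off call to the helper defined above, assuming a CUDA-capable device, with shape and tile values chosen only for illustration.

# Hypothetical direct invocation: float32 inputs, float16 output (as in
# test_elementwise_add_f32f16), with a smaller 64x128 tile per thread block.
run_elementwise_add(256, 512, "float32", "float16", block_M=64, block_N=128)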