Commit 46eb4589 authored by Zhengju Tang, committed by LeiWang1999

[CI] Add BlocksparseGemm, Dynamic, and Cast examples to CI (#467)



* [Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` so that the error for a missing `create_list_of_mbarrier` is raised only when `barrier_id_to_range_` is non-empty.
* Refactored the architecture checks in `phase.py` to use a new constant `SUPPORTED_TMA_ARCHS`, making the target-architecture validation logic easier to read and to extend (sketched below).
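
The refactor described above boils down to gating TMA lowering on a single allow-list instead of scattered per-architecture comparisons. The following is a minimal, hypothetical sketch of that pattern; only `SUPPORTED_TMA_ARCHS` and `phase.py` are named in the commit, so the concrete entries and the helper name here are assumptions.

# Hypothetical sketch of the SUPPORTED_TMA_ARCHS pattern; the entries and the
# helper name are assumptions, not tilelang's actual API.
SUPPORTED_TMA_ARCHS = ("sm_90", "sm_90a")  # extend this tuple to enable more targets


def allow_tma_lowering(target_arch: str) -> bool:
    """Return True when the target architecture is in the TMA allow-list."""
    return target_arch in SUPPORTED_TMA_ARCHS


if __name__ == "__main__":
    for arch in ("sm_80", "sm_90", "sm_90a"):
        print(arch, "->", "TMA enabled" if allow_tma_lowering(arch) else "TMA disabled")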

* [CI] Add BlocksparseGemm, Dynamic, and Cast examples to CI (see the sketch after this list for how the examples stay pytest-friendly).

* Lint
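
For the CI bullet above to work, the examples' argument parsing has to tolerate the test runner's own command-line flags; that is why the blocksparse example's diff below switches `parser.parse_args()` to `parser.parse_known_args()`. The snippet below is an illustration only: the parser and the argv values are made up for the demo and are not part of the commit.

# Illustration of why parse_known_args() is used when an example is driven by a
# test runner. The parser and sys.argv contents below are made up.
import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--sparsity", type=float, default=0.5)

# Pretend pytest invoked us and left its own options in sys.argv.
sys.argv = ["pytest", "-q", "--maxfail=1"]

# parser.parse_args() would exit here with "unrecognized arguments".
args, unknown = parser.parse_known_args()  # unknown == ["-q", "--maxfail=1"]
print(args.sparsity, unknown)              # -> 0.5 ['-q', '--maxfail=1']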

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent c99b7056
@@ -23,7 +23,7 @@ parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio
 parser.add_argument(
     "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
-args = parser.parse_args()
+args, _ = parser.parse_known_args()

 M, N, K = args.m, args.n, args.k
 sparsity = args.sparsity
 use_autotune = args.use_autotune
@@ -154,7 +154,7 @@ def blocksparse_matmul(M,
     block_mask_shape = (M // block_M, N // block_N, K // block_K)

     @T.prim_func
-    def main(
+    def block_sparse_matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            BlockMask: T.Tensor(block_mask_shape, "bool"),
@@ -178,10 +178,10 @@ def blocksparse_matmul(M,

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return block_sparse_matmul


-if __name__ == "__main__":
+def main():
     # Initialize input matrices A and B on the GPU with half precision
     a = torch.randn(M, K).cuda().half()
@@ -231,3 +231,7 @@ if __name__ == "__main__":
     except AssertionError as e:
         print("❌ Verification FAILED: Results differ significantly.")
         print(e)
+
+
+if __name__ == "__main__":
+    main()
import tilelang.testing
import example_blocksparse_gemm


def test_example_blocksparse_gemm():
    example_blocksparse_gemm.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -15,9 +15,9 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def main(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor((BG,), "int32"), X_fp8: T.Tensor(
-            (BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)),
-            accum_dtype)):
+    def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
+            (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor(
+            (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
         with T.Kernel(
                 T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
             row = bx
@@ -61,7 +61,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
                y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
                                     row_g_id * group_size:(row_g_id + 1) * group_size])

-    return main
+    return group_per_split_token_cast


 def ceil_div(x: int, y: int) -> int:
@@ -160,7 +160,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-if __name__ == "__main__":
+def main():
     M, N, BG, blk_m = 8192, 8192, 2, 8
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
@@ -184,7 +184,7 @@ if __name__ == "__main__":
         execution_backend="cython",
         pass_configs={"tl.disable_tma_lower": True})
     print(kernel.get_kernel_source())

-    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
+    # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     x_fp8, x_amax = kernel(x, batch_sizes)
     x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes)
@@ -208,3 +208,7 @@ if __name__ == "__main__":

     latency = do_bench(run_torch)
     print("Torch: {:.2f} ms".format(latency))
+
+
+if __name__ == "__main__":
+    main()
@@ -14,8 +14,8 @@ def per_token_cast_to_fp8(M, N, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def main(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"), X_amax: T.Tensor(
-            (M, T.ceildiv(N, group_size)), dtype)):
+    def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"),
+                       X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
         with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
             row = bx
             row_g_id = by
@@ -48,7 +48,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
                y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
                                     row_g_id * group_size:(row_g_id + 1) * group_size])

-    return main
+    return per_token_cast


 def ceil_div(x: int, y: int) -> int:
@@ -78,7 +78,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return x_fp8, (x_amax / 448.0).view(m, -1)


-if __name__ == "__main__":
+def main():
     M, N, blk_m = 8192, 8192, 8
     program = per_token_cast_to_fp8(M, N, blk_m)
     kernel = tilelang.compile(
@@ -120,3 +120,7 @@ if __name__ == "__main__":
     x_fp8_triton, x_amax_triton = run_triton()
     latency = do_bench(run_triton)
     print("Triton: {:.2f} ms".format(latency))
+
+
+if __name__ == "__main__":
+    main()
import tilelang.testing
import example_group_per_split_token_cast_to_fp8
import example_per_token_cast_to_fp8


def test_example_group_per_split_token_cast_to_fp8():
    example_group_per_split_token_cast_to_fp8.main()


def test_example_per_token_cast_to_fp8():
    example_per_token_cast_to_fp8.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -30,7 +30,7 @@ def matmul_dynamic_mnk(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

     @T.prim_func
-    def main(
+    def dynamic_matmul(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
@@ -52,11 +52,11 @@ def matmul_dynamic_mnk(
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return dynamic_matmul


-def test_matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
-                        accum_dtype, num_stages, threads):
+def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
+                   accum_dtype, num_stages, threads):
     print(
         f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
     )
@@ -104,6 +104,17 @@ def test_matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in
     print(f"Latency: {latency} ms")


+def main():
+    M, N, K = 16384, 16384, 16384
+    block_M, block_N, block_K = 128, 128, 32
+    trans_A, trans_B = False, False
+    in_dtype, out_dtype = "float16", "float16"
+    accum_dtype = "float32"
+    num_stages = 3
+    threads = 128
+    matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
+                   accum_dtype, num_stages, threads)
+
+
 if __name__ == "__main__":
-    test_matmul_dynamic(16384, 16384, 16384, 128, 128, 32, False, False, "float16", "float16",
-                        "float32", 3, 128)
+    main()
import tilelang.testing
import example_dynamic


def test_example_dynamic():
    example_dynamic.main()


if __name__ == "__main__":
    tilelang.testing.main()