Commit 46eb4589 authored by Zhengju Tang, committed by LeiWang1999

[CI] Add BlocksparseGemm, Dynamic, and Cast examples to CI (#467)



* [Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.

* [CI] Add BlocksparseGemm, Dynamic, and Cast examples to CI.

* Lint

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent c99b7056
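
A minimal sketch of the `SUPPORTED_TMA_ARCHS` refactor described in the second bullet of the commit message above. Only the constant name comes from the message; the architecture list and the helper below are illustrative assumptions, not the actual `phase.py` code:

# Hypothetical illustration of the described refactor (names other than
# SUPPORTED_TMA_ARCHS are assumed, not taken from the repository).
SUPPORTED_TMA_ARCHS = ("sm_90", "sm_90a")  # assumed Hopper-class targets with TMA support


def target_supports_tma(compute_version: str) -> bool:
    # One membership test against a shared constant replaces repeated
    # per-architecture comparisons; adding a new architecture only means
    # extending the constant.
    return compute_version in SUPPORTED_TMA_ARCHS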
@@ -23,7 +23,7 @@ parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio
 parser.add_argument(
     "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
-args = parser.parse_args()
+args, _ = parser.parse_known_args()
 M, N, K = args.m, args.n, args.k
 sparsity = args.sparsity
 use_autotune = args.use_autotune
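
The switch from `parse_args()` to `parse_known_args()` above is what lets the example run under a test harness that passes its own flags: unrecognized arguments are returned instead of causing the parser to exit. A small standalone illustration of this standard `argparse` behavior (the toy parser and the `-k` flag are only for demonstration):

import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--m", type=int, default=1024)

# parse_args() would exit with "unrecognized arguments"; parse_known_args()
# returns them alongside the parsed namespace.
args, unknown = demo_parser.parse_known_args(["--m", "256", "-k", "test_example"])
print(args.m)    # 256
print(unknown)   # ['-k', 'test_example']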
@@ -154,7 +154,7 @@ def blocksparse_matmul(M,
     block_mask_shape = (M // block_M, N // block_N, K // block_K)

     @T.prim_func
-    def main(
+    def block_sparse_matmul(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             BlockMask: T.Tensor(block_mask_shape, "bool"),
@@ -178,10 +178,10 @@ def blocksparse_matmul(M,
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return block_sparse_matmul


-if __name__ == "__main__":
+def main():
     # Initialize input matrices A and B on the GPU with half precision
     a = torch.randn(M, K).cuda().half()
@@ -231,3 +231,7 @@ if __name__ == "__main__":
     except AssertionError as e:
         print("❌ Verification FAILED: Results differ significantly.")
         print(e)
+
+
+if __name__ == "__main__":
+    main()
import tilelang.testing

import example_blocksparse_gemm


def test_example_blocksparse_gemm():
    example_blocksparse_gemm.main()


if __name__ == "__main__":
    tilelang.testing.main()
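
The new test module above follows the pattern used for the other examples in this commit: import the example, call its `main()` from a `test_*` function so pytest can collect it, and fall back to `tilelang.testing.main()` when run directly. A rough standalone equivalent of that entry point, written here only as an assumption about its behavior rather than tilelang's actual implementation:

import sys

import pytest


def run_tests_in_this_file():
    # Assumed equivalent of tilelang.testing.main(): run pytest on the current
    # file, forwarding any extra command-line arguments.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))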
@@ -15,9 +15,9 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def main(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor((BG,), "int32"), X_fp8: T.Tensor(
-            (BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)),
-            accum_dtype)):
+    def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
+            (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor(
+            (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
         with T.Kernel(
                 T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
             row = bx
@@ -61,7 +61,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
                 y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
                                      row_g_id * group_size:(row_g_id + 1) * group_size])

-    return main
+    return group_per_split_token_cast


 def ceil_div(x: int, y: int) -> int:
@@ -160,7 +160,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-if __name__ == "__main__":
+def main():
     M, N, BG, blk_m = 8192, 8192, 2, 8
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
@@ -184,7 +184,7 @@ if __name__ == "__main__":
         execution_backend="cython",
         pass_configs={"tl.disable_tma_lower": True})
     print(kernel.get_kernel_source())
-    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
+    # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     x_fp8, x_amax = kernel(x, batch_sizes)
     x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes)
@@ -208,3 +208,7 @@ if __name__ == "__main__":
     latency = do_bench(run_torch)
     print("Torch: {:.2f} ms".format(latency))
+
+
+if __name__ == "__main__":
+    main()
@@ -14,8 +14,8 @@ def per_token_cast_to_fp8(M, N, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def main(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"), X_amax: T.Tensor(
-            (M, T.ceildiv(N, group_size)), dtype)):
+    def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"),
+                       X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
         with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
             row = bx
             row_g_id = by
@@ -48,7 +48,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
                 y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
                                      row_g_id * group_size:(row_g_id + 1) * group_size])

-    return main
+    return per_token_cast


 def ceil_div(x: int, y: int) -> int:
@@ -78,7 +78,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return x_fp8, (x_amax / 448.0).view(m, -1)


-if __name__ == "__main__":
+def main():
     M, N, blk_m = 8192, 8192, 8
     program = per_token_cast_to_fp8(M, N, blk_m)
     kernel = tilelang.compile(
@@ -120,3 +120,7 @@ if __name__ == "__main__":
     x_fp8_triton, x_amax_triton = run_triton()
     latency = do_bench(run_triton)
     print("Triton: {:.2f} ms".format(latency))
+
+
+if __name__ == "__main__":
+    main()
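
For context on what the per-token cast kernels in these hunks compute: each row is split into groups, the per-group absolute maximum is taken, values are scaled so that the maximum maps to the e4m3 limit of 448.0, and the dequantization scale `amax / 448.0` is emitted, matching the `ref_program` return value shown above. A hedged eager-mode PyTorch sketch of that pattern, assuming `group_size = 128` and a PyTorch build that provides `torch.float8_e4m3fn` (both assumptions, not taken from this diff):

import torch


def per_token_cast_reference(x: torch.Tensor, group_size: int = 128):
    # x: (M, N) with N divisible by group_size (assumed layout).
    m, n = x.shape
    groups = x.view(m, n // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)  # per-group absolute max
    x_fp8 = (groups * (448.0 / amax)).to(torch.float8_e4m3fn).view(m, n)
    scales = (amax / 448.0).view(m, -1)  # dequantization scales, as in ref_program above
    return x_fp8, scales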
import tilelang.testing

import example_group_per_split_token_cast_to_fp8
import example_per_token_cast_to_fp8


def test_example_group_per_split_token_cast_to_fp8():
    example_group_per_split_token_cast_to_fp8.main()


def test_example_per_token_cast_to_fp8():
    example_per_token_cast_to_fp8.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -30,7 +30,7 @@ def matmul_dynamic_mnk(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

     @T.prim_func
-    def main(
+    def dynamic_matmul(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype),
@@ -52,11 +52,11 @@ def matmul_dynamic_mnk(
                 T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return dynamic_matmul


-def test_matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
-                        accum_dtype, num_stages, threads):
+def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
+                   accum_dtype, num_stages, threads):
     print(
         f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
     )
@@ -104,6 +104,17 @@ def test_matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in
     print(f"Latency: {latency} ms")


+def main():
+    M, N, K = 16384, 16384, 16384
+    block_M, block_N, block_K = 128, 128, 32
+    trans_A, trans_B = False, False
+    in_dtype, out_dtype = "float16", "float16"
+    accum_dtype = "float32"
+    num_stages = 3
+    threads = 128
+    matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
+                   accum_dtype, num_stages, threads)
+
+
 if __name__ == "__main__":
-    test_matmul_dynamic(16384, 16384, 16384, 128, 128, 32, False, False, "float16", "float16",
-                        "float32", 3, 128)
+    main()
import tilelang.testing

import example_dynamic


def test_example_dynamic():
    example_dynamic.main()


if __name__ == "__main__":
    tilelang.testing.main()