from dataclasses import dataclass, field

import tilelang.testing
import tilelang
import tilelang.language as T
from typing import Any
from itertools import product
import torch


def _gemm_impl():

    @T.macro
    def gemm_impl(
        A: T.Tensor[[int, int], Any],
        B: T.Tensor[[int, int], Any],
        C: T.Tensor[[int, int], Any],
        out_dtype: T.dtype,
        block_M: int,
        block_N: int,
        block_K: int,
    ):
        dtype = A.dtype
        M, K = A.shape
        K, N = B.shape
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[bx * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, by * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[bx * block_M, by * block_N])

    return gemm_impl


def test_jit2_gemm_annot():

    @tilelang.lazy_jit
    def gemm(
        A: T.Tensor[[int, int], Any],
        B: T.Tensor[[int, int], Any],
        out_dtype: T.dtype = T.float32,
        block_M: int = 64,
        block_N: int = 64,
        block_K: int = 32,
    ):
        M, K = A.shape
        K, N = B.shape
        C = T.empty(M, N, dtype=out_dtype)
        _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
        return C

    # Materialize the product so it can be iterated again in the check loop below.
    prod = list(product([T.float16, T.float32], [T.float32]))
    gemm.par_compile([{
        'A': T.Tensor((1024, 1024), dtype=in_dtype),
        'B': T.Tensor((1024, 1024), dtype=in_dtype),
        'out_dtype': out_dtype
    } for in_dtype, out_dtype in prod])
    for in_dtype, out_dtype in prod:
        in_dtype = in_dtype.torch()
        out_dtype = out_dtype.torch()
        A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        C_ref = (A @ B).to(out_dtype)
        C = gemm(A, B)
        torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)


def test_jit2_gemm_ptr():

    @tilelang.lazy_jit
    def gemm_ptr(
        A: T.ptr,
        B: T.ptr,
        C: T.ptr,
        M: int,
        N: int,
        K: int,
        dtype: T.dtype,
        out_dtype: T.dtype,
        block_M: int = 64,
        block_N: int = 64,
        block_K: int = 32,
    ):
        A = T.make_tensor(A, (M, K), dtype)
        B = T.make_tensor(B, (K, N), dtype)
        C = T.make_tensor(C, (M, N), out_dtype)
        _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)

    prod = list(product([T.float16, T.float32], [T.float32]))
    gemm_ptr.par_compile([{
        'A': T.ptr(),
        'B': T.ptr(),
        'C': T.ptr(),
        'M': 1024,
        'N': 1024,
        'K': 1024,
        'dtype': in_dtype,
        'out_dtype': out_dtype
    } for in_dtype, out_dtype in prod])
    for in_dtype, out_dtype in prod:
        in_dtype = in_dtype.torch()
        out_dtype = out_dtype.torch()
        A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        C_ref = (A @ B).to(out_dtype)
        C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda')
        gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
        torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)


def test_jit2_annot():
    from tilelang.language.v2.annot import Annot, ArgVarTable
    from tilelang.language.v2.builder import Builder
    import traceback

    @dataclass
    class AnnotTest:
        annot: Annot
        promote: Any
        match_ok: list[Any] = field(default_factory=list)
        match_ng: list[Any] = field(default_factory=list)

    tests = [
        AnnotTest(
            annot=T.Tensor[[int, int], T.float32],
            promote=False,
            match_ok=[torch.randn(1, 1, dtype=torch.float32),
                      T.Tensor((1, 1), dtype=T.float32)],
            match_ng=[
                torch.randn(1, 1, dtype=torch.float16),
                T.Tensor(1, dtype=T.float32),
                T.Tensor((1, 1), dtype=T.float16),
            ],
        ),
        AnnotTest(
            annot=T.Tensor[[int], Any],
            promote=False,
            match_ok=[
                torch.randn(12, dtype=torch.float32),
                torch.randn(12, dtype=torch.float16),
                T.Tensor((1,), dtype=T.float32),
                T.Tensor((1,), dtype=T.float16),
            ],
            match_ng=[torch.randn((1, 1), dtype=torch.float32),
                      T.Tensor((1, 1), dtype=T.float16)]),
        AnnotTest(
            annot=T.Tensor[[int, 1], Any],
            promote=False,
            match_ok=[
                torch.randn(12, 1, dtype=torch.float32),
                torch.randn(12, 1, dtype=torch.float16),
                T.Tensor((12, 1), T.float32),
                T.Tensor((12, 1), T.float16),
            ],
            match_ng=[torch.randn(12, 12, dtype=torch.float32),
                      T.Tensor((12, 12), T.float32)]),
        AnnotTest(
            annot=T.Tensor[[T.dyn, 1], Any],
            promote=False,
            match_ok=[
                torch.randn(12, 1, dtype=torch.float32),
                torch.randn(12, 1, dtype=torch.float16),
                T.Tensor((12, 1), T.float32),
                T.Tensor((12, 1), T.float16),
            ],
            match_ng=[torch.randn(12, 12, dtype=torch.float32),
                      T.Tensor((12, 12), T.float32)]),
        AnnotTest(
            annot=T.Tensor[[1024, 1024], T.float32],
            promote=True,
        ),
        AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]),
        AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]),
    ]

    for test in tests:
        promote = test.annot.promote()
        promoted = promote is not None
        if promoted != test.promote:
            raise AssertionError(
                f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}')
        with Builder().prim_func('_test'):
            for match_ok in test.match_ok:
                try:
                    vt = ArgVarTable()
                    test.annot.create_prim_func_arg('arg', match_ok, vt)
                except Exception as e:
                    traceback.print_exc()
                    raise AssertionError(
                        f'Match failed for {test.annot} with value {match_ok}: {e}') from e
            for match_ng in test.match_ng:
                vt = ArgVarTable()
                try:
                    # Non-matching values are expected to be rejected here; the check
                    # itself is raised outside the try so it is not swallowed below.
                    test.annot.create_prim_func_arg('arg', match_ng, vt)
                    matched = True
                except Exception:
                    matched = False
                if matched:
                    raise AssertionError(
                        f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')


def test_jit2_many_annot():

    @T.macro
    def copy_impl(A, B):
        M, N = A.shape
        M_, N_ = B.shape
        assert M == M_, f"M mismatch {M} {M_}"
        assert N == N_, f"N mismatch {N} {N_}"
        # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
        with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
            T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
                   B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])

    @tilelang.lazy_jit
    def copy1(
        A: T.Tensor[[int, int], T.float32],
        B: T.Tensor[[int, int], T.float32],
    ):
        copy_impl(A, B)

    @tilelang.lazy_jit
    def copy2(
        A: T.Tensor[[128, 128], T.float32],
        B: T.Tensor[[128, 128], T.float32],
    ):
        copy_impl(A, B)

    @tilelang.lazy_jit
    def copy3(
        A: T.Tensor[[int, 128], T.float32],
        B: T.Tensor[[int, 128], T.float32],
    ):
        copy_impl(A, B)

    @tilelang.lazy_jit
    def copy4(
        A: T.Tensor[[T.dyn, int], T.float32],
        B: T.Tensor[[T.dyn, int], T.float32],
    ):
        copy_impl(A, B)

    @tilelang.lazy_jit
    def copy5(
        A: T.StridedTensor[[int, int], [int, int], T.float32],
        B: T.StridedTensor[[int, int], [int, int], T.float32],
    ):
        copy_impl(A, B)

    @tilelang.lazy_jit
    def copy6(
        A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
        B: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
    ):
        copy_impl(A, B)

    for copy in [copy1, copy2, copy3, copy4]:
        A = torch.randn(128, 128, device='cuda')
        B = torch.empty(128, 128, device='cuda')
        copy(A, B)
        assert torch.equal(B, A)

    for copy in [copy5, copy6]:
        A = torch.randn(128, 2, 128, 2, device='cuda')
        B = torch.randn(128, 2, 128, 2, device='cuda')
        copy(A[:, 0, :, 0], B[:, 0, :, 0])
        assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])


def test_jit2_return():

    @T.macro
    def copy_impl(A):
        M, N = A.shape
        B = T.empty(M, N, dtype=A.dtype)
        M, N = A.shape
        M_, N_ = B.shape
        assert M == M_, f"M mismatch {M} {M_}"
        assert N == N_, f"N mismatch {N} {N_}"
        # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
        with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
            T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
                   B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])
        return B

    @tilelang.lazy_jit
    def copy0(A: T.Tensor[[int, int], Any]):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy1(A: T.Tensor[[int, int], T.float32],):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy2(A: T.Tensor[[128, 128], T.float32],):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy3(A: T.Tensor[[int, 128], T.float32],):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy4(A: T.Tensor[[T.dyn, int], T.float32],):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],):
        return copy_impl(A)

    @tilelang.lazy_jit
    def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],):
        return copy_impl(A)

    for copy in [copy0, copy1, copy2, copy3, copy4]:
        A = torch.randn(128, 128, device='cuda')
        B = copy(A)
        assert torch.equal(B, A)

    for copy in [copy5, copy6]:
        A = torch.randn(128, 2, 128, 2, device='cuda')
        B = copy(A[:, 0, :, 0])
        assert torch.equal(A[:, 0, :, 0], B)


def test_jit2_deepseek_deepgemm():

    @tilelang.lazy_jit
    def deep_gemm(
        A: T.Tensor[[int, int], T.float8_e4m3],
        B: T.Tensor[[int, int], T.float8_e4m3],
        scales_a: T.Tensor[[int, int], T.float32],
        scales_b: T.Tensor[[int, int], T.float32],
        out_dtype: T.dtype = T.bfloat16,
        accum_dtype: T.dtype = T.float32,
        block_N: int = 128,
        block_M: int = 128,
        block_K: int = 128,
    ):
        # A: [M, K]
        # B: [N, K]
        # scales_a: [M, K // 128]
        # scales_b: [N, K // 128]
        # C: [M, N]
        group_size = 128
        in_dtype = A.dtype
        M, K = A.shape
        N, K = B.shape
        C = T.empty(M, N, dtype=out_dtype)
        assert out_dtype in [T.bfloat16, T.float32], \
            f"Expect out_dtype to be one of [T.bfloat16, T.float32], got {out_dtype}"
        assert scales_a.shape == [M, T.ceildiv(K, group_size)], \
            f"Expect scales_a shape to be {[M, T.ceildiv(K, group_size)]}"
        assert scales_b.shape == [N, T.ceildiv(K, group_size)], \
            f"Expect scales_b shape to be {[N, T.ceildiv(K, group_size)]}"

        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), in_dtype)
            B_shared = T.alloc_shared((block_N, block_K), in_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            scale_C_shared = T.alloc_shared((block_M,), T.float32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.use_swizzle(panel_size=10)
            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                Scale_B = scales_b[bx * block_N // group_size, k]
                for i in T.Parallel(block_M):
                    scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * scale_C_shared[i]
                T.clear(C_local)
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])
        return C

    # def ceildiv(a, b):
    #     return (a + b - 1) // b

    # def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
    #     # A_scale: (M, K//128) ==> (M//128, K//128, 128)
    #     # B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
    #     # A_fp8: (M, K)
    #     # B_fp8: (N, K)
    #     # out_dtype: float16 or float32
    #     # return C: (M, N)
    #     M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
    #     A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
    #     B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
    #     C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
    #     c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
    #     for i in range(ceildiv(M, 128)):
    #         for j in range(ceildiv(N, 128)):
    #             c_acc.zero_()
    #             for k in range(ceildiv(K, 128)):
    #                 c = torch._scaled_mm(
    #                     A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
    #                     B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
    #                     scale_a=A_scales[i, k].view(128, 1).contiguous(),
    #                     scale_b=B_scales[j, k].view(1, 128).contiguous(),
    #                     out_dtype=torch.bfloat16)
    #                 c_acc += c.to(torch.float32)
    #             C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
    #     return C

    # M, N, K = 1024, 1024, 8192
    # A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )


if __name__ == '__main__':
    tilelang.testing.main()