import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function designed
# specifically for MMA operations. It keeps shared-memory layouts consistent
# with the NVIDIA CUTLASS library, avoiding bank conflicts and maximizing
# performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,)  # noqa: F401


# Add the decorator @tilelang.jit if you want matmul(...) to return a
# torch-callable function directly (see the sketch after this function)
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context: launch a grid of
        # ceil(N / block_N) x ceil(M / block_M) thread blocks, 128 threads each;
        # (bx, by) identify the output tile computed by this block
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)  # shared-memory tile of A
            B_shared = T.alloc_shared((block_K, block_N), dtype)  # shared-memory tile of B
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)  # register accumulator

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
            # T.annotate_layout({
            #     A_shared: make_swizzle_layout(A_shared),
            #     B_shared: make_swizzle_layout(B_shared),
            # })

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            # Iterate over K in block_K-sized tiles; T.Pipelined overlaps the
            # global-to-shared copies with compute via a 3-stage software pipeline
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A into shared memory.
                # T.copy is syntactic sugar for a parallelized copy:
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy a tile of B into shared memory
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared-memory buffers.
                # This currently dispatches to CUTLASS CuTe on NVIDIA GPUs and
                # to HIP on AMD GPUs.
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
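

# A minimal sketch of the decorator path mentioned above, commented out on
# purpose. It assumes @tilelang.jit accepts an out_idx argument like
# tilelang.compile below (hypothetical; check the tilelang docs):
#
# @tilelang.jit(out_idx=[2])
# def matmul_jit(M, N, K, block_M, block_N, block_K):
#     ...  # same T.prim_func body as in matmul above
#
# kernel = matmul_jit(1024, 1024, 1024, 128, 128, 32)
# c = kernel(a, b)  # callable directly with torch tensors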


M = 1024  # use M = T.symbolic("m") for a dynamic shape (see the sketch below)
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
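
# A minimal sketch of the dynamic-shape variant mentioned above, assuming
# T.symbolic takes a variable name as in the comment on M (commented out):
# M_dyn = T.symbolic("m")
# dyn_func = matmul(M_dyn, N, K, block_M, block_N, block_K)
# The compiled kernel would then accept inputs with any runtime value of M.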

# 1. Define the kernel: matmul(...) constructs the TIR program (the actual
#    compilation happens in step 2)
func = matmul(M, N, K, block_M, block_N, block_K)

# 2. Compile the kernel into a torch-callable function.
# out_idx specifies the index of the output buffer in the argument list;
# if out_idx is given, that tensor is allocated at runtime and returned.
# target can currently be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="cython")
# jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="dlpack")
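
# With out_idx=[2], the compiled kernel allocates the output tensor C itself
# and returns it, so the call below is c = jit_kernel(a, b). A hedged sketch
# of the variant without out_idx (argument convention assumed, not verified):
# jit_kernel_explicit = tilelang.compile(func, target="cuda")
# c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# jit_kernel_explicit(a, b, c)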

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel on the input tensors
c = jit_kernel(a, b)

print(c)
# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile kernel latency; tensor_supply_type controls how the profiler
#    generates random inputs for benchmarking
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")