import tilelang
import tilelang.language as T


# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A into shared memory.
                # T.copy is syntactic sugar for a parallelized copy.
                T.copy(A[by * block_M, ko * block_K], A_shared)
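                # A sketch of roughly what this copy expands to (illustrative
                # only, not the actual lowering):
                #   for i, k in T.Parallel(block_M, block_K):
                #       A_shared[i, k] = A[by * block_M + i, ko * block_K + k]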

                # Copy a tile of B into shared memory
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Under the hood, this currently dispatches to CuTe on NVIDIA GPUs and HIP on AMD GPUs.
                T.gemm(A_shared, B_shared, C_local)

            # Apply ReLU to the accumulator
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


M = 1024  # M = T.symbolic("m") if you want to use dynamic shapes
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32

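# A minimal sketch of the dynamic-shape variant, following the T.symbolic
# hint above (the names M_dyn and dyn_kernel are illustrative, not part of
# this script): pass a symbolic dim instead of a Python int, and the
# concrete value is read from the input tensors at call time.
#
#   M_dyn = T.symbolic("m")
#   dyn_kernel = matmul(M_dyn, N, K, block_M, block_N, block_K)
#   dyn_kernel(a, b, c)  # m is inferred from a's shape at runtime
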
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 2. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel
matmul_relu_kernel(a, b, c)

print(c)
# Compute the reference result (matmul + ReLU) with PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 3. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 4. Profile kernel latency with the built-in profiler
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")