"docs/vscode:/vscode.git/clone" did not exist on "6c1fe5c811f82c571af94f787b4721f3a1cc7ca4"
quickstart.py 2.91 KB
Newer Older
1
2
import tilelang
import tilelang.language as T


# Target selection: writing `@tilelang.jit(target="cuda")` pins the backend
# explicitly; the target can currently be "cuda", "hip" or "cpu". When no
# target is given (as below), it is inferred from the input tensors at
# compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
    """Build a tiled kernel computing ``C = relu(A @ B)``.

    Args:
        M, N, K: Problem sizes; A is (M, K), B is (K, N) and C is (M, N).
        block_M, block_N, block_K: Tile sizes along M, N and the K reduction.
        dtype: Element type of the A, B and C tensors.
        accum_dtype: Element type of the per-tile accumulator fragment.

    Returns:
        The ``T.prim_func`` kernel. Through ``@tilelang.jit``, calling
        ``matmul(...)`` yields a compiled kernel that is invoked as
        ``kernel(A, B, C)`` and writes its result into ``C``.
    """

    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context: a ceildiv(N, block_N) x
        # ceildiv(M, block_M) launch grid with 128 threads each. (bx, by)
        # selects the output tile starting at C[by * block_M, bx * block_N].
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear the local accumulator before the K reduction
            T.clear(C_local)

            # Walk the K dimension one block_K slice at a time; T.Pipelined
            # requests a software-pipelined loop with num_stages=3.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A.
                # T.copy is syntactic sugar for a parallelized copy.
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy a tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers.
                # This currently dispatches to CuTe on NVIDIA GPUs and to
                # HIP on AMD GPUs.
                T.gemm(A_shared, B_shared, C_local)

            # Fused ReLU epilogue, applied elementwise to the accumulator
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy the result tile back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


# Problem size and tile configuration.
M = 1024  # use M = T.dynamic("m") instead if you want a dynamic shape
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32

# Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
# The output needs no initialization: the launch grid covers every tile of C
# and the kernel writes each one.
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel directly; it writes its result into `c` in place
matmul_relu_kernel(a, b, c)

print(c)

# Reference result computed with PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness (loose tolerances account for float16 output)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# Profile the kernel's latency with its profiler, which supplies normally
# distributed input tensors
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")