# example_elementwise_add.py (2.8 KB)
1
2
3
import argparse
import itertools
import torch
4
5
import tilelang
import tilelang.language as T
6
from tilelang.autotuner import AutoTuner
7
8


9
10
11
12
def ref_program(x, y):
    """Reference implementation: return the sum ``x + y``.

    Used as the ground truth that the TileLang kernel output is compared
    against.
    """
    expected = x + y
    return expected


13
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Build a tiled element-wise addition kernel computing ``C = A + B``.

    Args:
        M: Number of rows of the (M, N) tensors.
        N: Number of columns of the (M, N) tensors.
        block_M: Tile height handled by one kernel block.
        block_N: Tile width handled by one kernel block.
        in_dtype: Element dtype of the inputs ``A`` and ``B``.
        out_dtype: Element dtype of the output ``C``.
        threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``T.prim_func`` ``elem_add`` (as wrapped by ``tilelang.jit``).

    NOTE(review): ``out_idx=[-1]`` presumably marks the last parameter
    (``C``) as the output returned by the compiled kernel; callers in this
    file invoke it as ``kernel(a, b)`` -- confirm against the TileLang docs.
    """

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
        # One block per (block_M x block_N) tile; bx indexes tiles along N,
        # by indexes tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # Stage this block's input tiles.
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)

            # Element-wise sum of the two staged tiles.
            for local_y, local_x in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]

            # Write the result tile back to C via the C_shared buffer.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add
31
32


33
34
35
36
37
38
39
40
41
42
43
44
def get_configs(M, N):
    """Return every candidate tuning configuration as a list of dicts.

    The search space is the Cartesian product of the tile heights, tile
    widths and thread counts below, with ``threads`` varying fastest.
    ``M`` and ``N`` are accepted for interface compatibility but do not
    influence the result.
    """
    tile_heights = (64, 128, 256)
    tile_widths = (64, 128, 256)
    thread_counts = (64, 128, 256)
    return [
        {"block_M": height, "block_N": width, "threads": count}
        for height in tile_heights
        for width in tile_widths
        for count in thread_counts
    ]


def get_best_config(M, N):
    """Autotune the element-wise add kernel for an (M, N) problem.

    Every configuration from ``get_configs(M, N)`` is tried with float32
    inputs and outputs, and each candidate is checked against
    ``ref_program`` (``skip_check=False``).

    Args:
        M: Number of rows of the tensors.
        N: Number of columns of the tensors.

    Returns:
        Whatever ``AutoTuner.run`` returns; the caller in this file reads
        the tuned kernel from its ``.kernel`` attribute.
    """

    def kernel(block_M=None, block_N=None, threads=None):
        # Only the tunable parameters are exposed; shape and dtypes are fixed.
        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N))
        .set_compile_args(
            out_idx=[-1],
            target="cuda",
        )
        .set_profile_args(
            supply_type=tilelang.TensorSupplyType.Auto,
            ref_prog=ref_program,
            skip_check=False,
        )
    )
    return autotuner.run(warmup=3, rep=20)
58

59

60
def main():
    """Run the element-wise add example and validate it against PyTorch.

    Command-line flags:
        --m, --n: Tensor dimensions (default 1024 each).
        --use_autotune: Select the kernel via ``get_best_config`` instead
            of the fixed default configuration.

    Raises:
        AssertionError: If the kernel output differs from ``ref_program``
            beyond ``rtol=1e-2`` / ``atol=1e-2``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    # parse_known_args tolerates unrecognized arguments instead of exiting.
    args, _ = parser.parse_known_args()
    M, N = args.m, args.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    if args.use_autotune:
        result = get_best_config(M, N)
        kernel = result.kernel
    else:
        # Default config
        config = {"block_M": 32, "block_N": 32, "threads": 128}
        kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
81
82
83
84


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()