"...composable_kernel_onnxruntime.git" did not exist on "39d92e7dfdb2893a0e7d0521523c442ec403712c"
example_elementwise_add.py 2.4 KB
Newer Older
1
2
3
import argparse
import itertools
import torch
4
5
import tilelang
import tilelang.language as T
6
from tilelang.autotuner import AutoTuner
7
8


9
10
11
12
def ref_program(x, y):
    """Reference result for the elementwise-add kernel: ``x + y``.

    Used as the ground truth both by the autotuner (``ref_prog``) and by the
    final ``assert_close`` check in ``main``.
    """
    total = x + y
    return total


13
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Build a TileLang kernel that computes ``C = A + B`` on ``(M, N)`` tensors.

    ``out_idx=[-1]`` marks the last tensor parameter (``C``) as the output,
    so the compiled kernel is called with the two inputs only and returns
    ``C`` (``main`` calls it as ``kernel(a, b)``).

    Args:
        M, N: Tensor shape shared by ``A``, ``B`` and ``C``.
        block_M, block_N: Tile size (rows x columns) handled by one block.
        in_dtype: Dtype string of ``A`` and ``B`` (e.g. ``"float32"``).
        out_dtype: Dtype string of ``C``.
        threads: Threads per block passed to ``T.Kernel``.

    Returns:
        Whatever ``tilelang.jit`` produces for the ``elem_add`` prim_func.
    """

    @T.prim_func
    def elem_add(
            A: T.Tensor((M, N), in_dtype),
            B: T.Tensor((M, N), in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        # Grid axis x walks column tiles (N / block_N), axis y walks row
        # tiles (M / block_M); ceildiv rounds up so partial tiles are launched.
        # NOTE(review): there is no explicit `y < M` / `x < N` guard below, so
        # when M or N is not a multiple of the tile size the last tile indexes
        # past the tensor edge — confirm TileLang's lowering inserts bounds
        # checks, or keep M/N multiples of block_M/block_N.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            start_x = bx * block_N
            start_y = by * block_M
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                y = start_y + local_y
                x = start_x + local_x
                C[y, x] = A[y, x] + B[y, x]

    return elem_add
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_configs(M, N):
    """Return the autotuning search space for ``elementwise_add``.

    Every combination of ``block_M``, ``block_N`` and ``threads`` drawn from
    64, 128 and 256 is listed (27 dicts). ``block_M`` varies slowest and
    ``threads`` varies fastest. ``M`` and ``N`` are currently unused.
    """
    candidate_sizes = (64, 128, 256)
    return [
        {"block_M": rows, "block_N": cols, "threads": n_threads}
        for rows in candidate_sizes
        for cols in candidate_sizes
        for n_threads in candidate_sizes
    ]


def get_best_config(M, N):
    """Autotune ``elementwise_add`` for an ``(M, N)`` float32 problem.

    Each configuration from ``get_configs(M, N)`` is compiled with
    ``out_idx=[-1]`` for the ``"cuda"`` target and profiled. Inputs come from
    ``TensorSupplyType.Auto``, and ``ref_program`` is the reference
    (``skip_check=False``). The run uses ``warmup=3, rep=20``.

    Returns:
        The object returned by ``AutoTuner.run``. ``main`` reads its
        ``.kernel`` attribute to obtain the selected kernel.
    """

    def kernel(block_M=None, block_N=None, threads=None):
        # The autotuner fills these keyword arguments from each config dict.
        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)

    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(M, N)).set_compile_args(
            out_idx=[-1],
            target="cuda",
        ).set_profile_args(
            supply_type=tilelang.TensorSupplyType.Auto,
            ref_prog=ref_program,
            skip_check=False,
        )
    return autotuner.run(warmup=3, rep=20)
53

54

55
def main():
    """Run the elementwise-add example and verify it against ``ref_program``.

    Command-line options (unknown options are ignored, not rejected):
      --m, --n        Tensor shape (default 512 x 1024).
      --use_autotune  Pick the tile sizes and thread count with
                      ``get_best_config`` instead of the fixed default.

    Requires a CUDA device, since both inputs are allocated on ``"cuda"``.

    Raises:
        AssertionError: If the kernel output differs from ``a + b`` beyond
            ``rtol=1e-2`` / ``atol=1e-2`` (from ``torch.testing.assert_close``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=512)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    args, _ = parser.parse_known_args()
    M, N = args.m, args.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    if args.use_autotune:
        result = get_best_config(M, N)
        kernel = result.kernel
    else:
        # Default config; it divides the default 512 x 1024 shape evenly.
        config = {"block_M": 128, "block_N": 256, "threads": 128}
        kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    main()