example_conv_analyze.py 3.37 KB
Newer Older
1
2
3
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
4
from tilelang.carver.arch import CDNA
5
from tilelang.layout import make_swizzled_layout
6
import torch
Gabriel Wu's avatar
Gabriel Wu committed
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Problem size for the example convolution; these are passed positionally to
# kernel(N, C, H, W, F, K, S, D, P, ...) by main().
N, C, H, W = 64, 256, 512, 512  # batch, input channels, input height, input width
F = 512  # output channels (filters)
K = 3  # square filter size (KH = KW = K)
S, D, P = 1, 1, 1  # stride, dilation, zero padding


def check_hopper(*, detect=False):
    """Report whether the kernel should take the Hopper (sm_90) code path.

    Detection is disabled by default, so callers that use ``check_hopper()``
    always get ``False`` and the manual im2col gather path.

    Args:
        detect: when True, query device 0 through PyTorch and return True only
            for an NVIDIA GPU with compute capability 9.0.

    Returns:
        bool: False unless ``detect`` is True, PyTorch is a CUDA (non-ROCm)
        build with a visible device, and device 0 reports capability (9, 0).
    """
    if not detect:
        return False
    # ROCm builds also expose devices through torch.cuda, and their
    # major/minor values may collide with NVIDIA's (e.g. 9.0), so only trust
    # the compute capability on CUDA builds.
    if torch.version.hip is not None or not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    return (props.major, props.minor) == (9, 0)


def kernel(N,
           C,
           H,
           W,
           F,
           K,
           S,
           D,
           P,
           block_M,
           block_N,
           block_K,
           num_stages,
           threads,
           dtype="float16",
           accum_dtype="float"):
    """Build a TileLang prim_func computing a 2-D convolution as an implicit GEMM.

    The convolution is mapped to a GEMM of shape
    (N * OH * OW) x F x (KH * KW * C). Each output pixel is a GEMM row, each
    output channel is a GEMM column, and the reduction runs over the flattened
    (kh, kw, c) filter window.

    Args:
        N, C, H, W: batch size, input channels, and input height/width of the
            NHWC input tensor ``data``.
        F: number of output channels (filters).
        K: square filter size (KH = KW = K).
        S: stride.
        D: dilation.
        P: zero padding applied on each spatial border.
        block_M, block_N, block_K: tile sizes along the GEMM M (output pixels),
            N (output channels), and K (reduction) dimensions.
        num_stages: stage count passed to ``T.Pipelined``.
        threads: threads per block passed to ``T.Kernel``.
        dtype: element type of ``data``, ``kernel``, ``out``, and the shared
            staging buffers.
        accum_dtype: element type of the ``out_local`` accumulator.

    Returns:
        The ``conv`` prim_func. It reads ``data`` (N, H, W, C) and ``kernel``
        (KH, KW, C, F), and writes ``out`` (N, OH, OW, F).
    """
    KH, KW = K, K
    # Output size of a padded, dilated, strided convolution.
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    # dtype / accum_dtype are taken from the arguments as given.
    is_hopper = check_hopper()

    @T.prim_func
    def conv(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        # Grid: bx walks block_N-wide tiles of the F output channels, and by
        # walks block_M-wide tiles of the flattened N * OH * OW output pixels.
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            # 2-D views over the same storage: the filter as a
            # (KH * KW * C) x F GEMM operand and the output as an
            # (N * OH * OW) x F GEMM result.
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: make_swizzled_layout(out_shared),
                data_shared: make_swizzled_layout(data_shared),
                kernel_shared: make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            # Reduce over the flattened (kh, kw, c) window, block_K at a time.
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                if is_hopper:
                    T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                else:
                    # Manual im2col gather. Row i is output pixel
                    # m = (n * OH + oh) * OW + ow, and column j is reduction
                    # index k = (kh * KW + kw) * C + c.
                    # NOTE(review): m and k are not bounds-checked; this
                    # assumes block_M divides N * OH * OW and block_K divides
                    # KH * KW * C (true for the sizes used in main) - confirm
                    # before using other shapes.
                    for i, j in T.Parallel(block_M, block_K):
                        k = k_iter * block_K + j
                        m = by * block_M + i
                        # Input row/column = out_pos * stride + tap * dilation - pad.
                        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                        access_w = m % OW * S + k // C % KW * D - P
                        in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                    (access_w < W))
                        # Positions that fall in the padding region read as 0.
                        data_shared[i, j] = T.if_then_else(
                            in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

            # Stage the tile through shared memory before the global store.
            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return conv
95
96


97
98
def main():
    """Build the example convolution kernel and print its static analysis.

    Targets an AMD CDNA architecture when PyTorch is a ROCm/HIP build
    (``torch.version.hip`` is set), and an NVIDIA CUDA architecture otherwise.
    """
    # block_M=64, block_N=128, block_K=32, num_stages=3, threads=256
    my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
    # torch.version.hip is None on CUDA builds of PyTorch and a version string
    # on ROCm builds. The original selected these the wrong way round.
    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
    result = Analyzer.analysis(my_func, cuda_device)
    print(result)
    print(f"Analyzed FLOPs: {result.total_flops}")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()