# example_conv_analyze.py
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA, CDNA
from tilelang.layout import make_swizzled_layout
import torch
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
N = 64  # batch size
C = 256  # input channels
H = 512  # input height
W = 512  # input width
F = 512  # output channels (number of filters)
K = 3  # kernel size (KH = KW = K)
S = 1  # stride
D = 1  # dilation
P = 1  # padding


def check_hopper():
    # Always return False so the generic im2col path below is used; the original
    # Hopper (sm_90) compute-capability check is kept here for reference.
    # if not torch.cuda.is_available():
    #     return None
    # props = torch.cuda.get_device_properties(0)
    # compute_capability = props.major, props.minor
    # return compute_capability == (9, 0)
    return False


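# The kernel below lowers an NHWC convolution to an implicit GEMM (im2col):
# GEMM rows are the flattened output positions (N * OH * OW), GEMM columns are
# the output channels (F), and the reduction runs over KH * KW * C.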
def kernel(N,
           C,
           H,
           W,
           F,
           K,
           S,
           D,
           P,
           block_M,
           block_N,
           block_K,
           num_stages,
           threads,
           dtype="float16",
           accum_dtype="float"):
    KH, KW = K, K
    # Output spatial dimensions with padding P, stride S, and dilation D.
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    is_hopper = check_hopper()

    @T.prim_func
    def conv(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
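            # Shared-memory tiles for the im2col data and the weights, a register
            # fragment for the accumulator, and a shared staging buffer for the output.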
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

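            # View the weight and output tensors as 2-D GEMM operands over the same storage.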
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

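            # Swizzle the shared-memory layouts to reduce bank conflicts.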
            T.annotate_layout({
                out_shared: make_swizzled_layout(out_shared),
                data_shared: make_swizzled_layout(data_shared),
                kernel_shared: make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
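            # Software-pipelined reduction over the flattened KH * KW * C dimension,
            # block_K elements per stage: load an im2col tile and a weight tile, then
            # accumulate their product into out_local.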
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                if is_hopper:
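                    # Hopper path: the built-in conv2d im2col op fills data_shared directly.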
                    T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                else:
                    for i, j in T.Parallel(block_M, block_K):
                        k = k_iter * block_K + j
                        m = by * block_M + i
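                        # Decode the flattened indices: m -> (batch, oh, ow) and
                        # k -> (kh, kw, c); map back to input coordinates with stride S,
                        # dilation D, and padding P. Out-of-bounds taps read zero below.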
                        access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                        access_w = m % OW * S + k // C % KW * D - P
                        in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
                                    (access_w < W))
                        data_shared[i, j] = T.if_then_else(
                            in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

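            # Epilogue: stage the accumulator through swizzled shared memory before
            # writing the tile to the flattened output.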
            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return conv


def main():
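    # Tile configuration: block_M=64, block_N=128, block_K=32, num_stages=3, threads=256.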
    my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
    # Pick the architecture model that matches the available backend (ROCm vs. CUDA).
    if torch.version.hip is not None:
        arch = CDNA("hip")
    else:
        arch = CUDA("cuda")
    result = Analyzer.analysis(my_func, arch)
    print(result)
    print(f"Analyzed FLOPs: {result.total_flops}")


if __name__ == "__main__":
    main()