# example_tilelang_gemm_streamk.py 5.86 KB
# Newer Older
# root's avatar
# init
# root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import torch
import torch.backends
import tilelang
from tilelang import language as T
import math


def cdiv(a, b):
    """Return the ceiling of ``a / b``.

    Integer operands are divided with exact integer arithmetic. Going through
    ``a / b`` first converts the quotient to a float, which silently loses
    precision once it needs more than 53 bits and then rounds to the wrong
    integer. Any other operand types keep the float-based behaviour.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer that is greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    if isinstance(a, int) and isinstance(b, int):
        # ``//`` rounds toward -inf, so negating twice rounds toward +inf.
        return -(-a // b)
    return math.ceil(a / b)


# disable tf32
torch.backends.cuda.matmul.allow_tf32 = False

# Problem size: the kernel result is compared against A[m, k] @ B[n, k].T.
m = 256
n = 1024
k = 512

# Used both as the number of programs launched by the kernel (through
# streamk_programs) and as the tile stride between data-parallel rounds.
# NOTE(review): 108 presumably matches the SM count of the target GPU — confirm.
total_sm = 108

torch.random.manual_seed(0)
# uniform distribution from -1 to 1
A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1
B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1

streamk_programs = total_sm
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
two_tiles = False  # not read anywhere in the visible code
M, K = A.shape
N, K = B.shape
# compute grid (work to do per SM on the first wave)
num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M)
num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N)
iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_block_m * num_block_n

# Two-tile SK + DP
# Tiles left over after filling whole waves of `streamk_programs` tiles go to
# the stream-K phase; when more than one whole wave remains, one extra wave is
# moved into the stream-K phase as well.
streamk_tiles = total_tiles % streamk_programs
if (total_tiles - streamk_tiles > streamk_programs):  # (total_tiles // total_programs > 1)
    streamk_tiles += streamk_programs

# Tiles handled whole by a single program in the data-parallel phase. By
# construction this is a multiple of streamk_programs (possibly zero).
blocking_tiles = total_tiles - streamk_tiles
streamk_iters = streamk_tiles * iters_per_tile

# Despite the names, both values count K-iterations, not tiles: every program
# gets `streamk_full_tiles` iterations and the first `streamk_partial_tiles`
# programs get one more.
streamk_full_tiles = streamk_iters // streamk_programs
streamk_partial_tiles = streamk_iters % streamk_programs

print(f"{total_tiles=} ")
print(f"{iters_per_tile=} ")

# Number of data-parallel rounds each program runs. This must be 0 when there
# are no blocking tiles: clamping it to at least 1 would make every program
# process tile `pid + streamk_tiles`, which lies past `total_tiles` in that
# case, and would make the kernel's `sm_patition_factor > 0` guard always true.
# (The misspelt name is kept because the kernel builder refers to it.)
sm_patition_factor = blocking_tiles // total_sm


@tilelang.jit
def tl_matmul_streamk(
    M,
    N,
    K,
    streamk_tiles,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
):
    """Build the stream-K GEMM kernel and return its ``T.prim_func``.

    The kernel runs two phases per program: ``compute_first_wave`` works on a
    contiguous range of K-iterations that may start or stop in the middle of
    an output tile, then ``compute_full_tiles`` computes whole output tiles.

    NOTE(review): besides its arguments, the kernel reads the module-level
    values ``streamk_programs``, ``streamk_full_tiles``,
    ``streamk_partial_tiles``, ``iters_per_tile``, ``sm_patition_factor`` and
    ``total_sm``. Those are derived from the module-level problem size, so the
    arguments passed here have to agree with them.

    Args:
        M, N, K: GEMM dimensions; C is declared as (M, N).
        streamk_tiles: Index of the first tile handled by
            ``compute_full_tiles`` (tiles before it belong to the first wave).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: Must be falsy (asserted below).
        trans_B: When truthy B is declared as (N, K), otherwise (K, N); also
            forwarded to ``T.gemm`` as ``transpose_B``.
        dtypeAB: Element type of A, B and their shared-memory tiles.
        dtypeC: Element type of C.
        accum_dtype: Element type of the accumulator fragment.
        num_stages: ``num_stages`` of the pipelined K loop in the first wave.
        threads: ``threads`` argument of ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.

    Raises:
        AssertionError: If ``trans_A`` is truthy.
    """
    assert not trans_A
    A_shape = (M, K) if not trans_A else (K, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    # Phase 1: each program consumes its share of the stream-K iterations,
    # one (possibly partial) tile segment per loop trip.
    @T.macro
    def compute_first_wave(
        pid: T.int32,
        A_buf: T.Tensor,
        A_buf_shared: T.SharedBuffer,
        B_buf: T.Tensor,
        B_buf_shared: T.SharedBuffer,
        C: T.Tensor,
        C_local: T.LocalBuffer,
    ):
        # One-element int32 buffers holding the loop bounds that are
        # reassigned on every trip of the while loop below.
        start_iter = T.alloc_fragment((1,), "int32", "local")
        end_iter = T.alloc_fragment((1,), "int32", "local")

        # Program `pid` owns iterations [start_iter, last_iter): every program
        # gets `streamk_full_tiles` iterations and the first
        # `streamk_partial_tiles` programs get one extra.
        start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
        last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)

        while start_iter[0] < last_iter:
            # Stop at the next tile boundary or at the end of this program's
            # range, whichever comes first.
            end_iter[0] = T.min(
                start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)),
                last_iter,
            )

            # Tile being worked on and the K-iteration offset inside it.
            tile_id = start_iter[0] // iters_per_tile
            remain_iters = start_iter[0] % iters_per_tile
            # Tiles are numbered row-major over the (M-blocks, N-blocks) grid.
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)

            T.clear(C_local)
            for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
                T.copy(
                    A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    A_buf_shared,
                )
                # NOTE(review): B is indexed as [n, k], which matches B_shape
                # only when trans_B is truthy — confirm before using
                # trans_B=False.
                T.copy(
                    B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    B_buf_shared,
                )
                T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B)

            # last iteration of the tile always happens before its start on another SM
            # A segment that starts at offset 0 and ends on a tile boundary
            # covers the whole tile, so it is stored directly; any other
            # segment is a partial sum and is accumulated into C element-wise.
            if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0):
                T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
            else:
                for i, j in T.Parallel(block_M, block_N):
                    T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j])

            # Continue with the next segment of this program's range.
            start_iter[0] = end_iter[0]

    # Phase 2: each program computes whole tiles, one per round.
    @T.macro
    def compute_full_tiles(
        pid: T.int32,
        A_buf: T.Tensor,
        A_shared: T.SharedBuffer,
        B_buf: T.Tensor,
        B_shared: T.SharedBuffer,
        C: T.Tensor,
        C_local: T.LocalBuffer,
    ):

        for p in T.serial(sm_patition_factor):
            # Skip the tiles owned by the first wave; successive rounds are
            # `total_sm` tiles apart.
            tile_id = pid + streamk_tiles + p * total_sm
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)
            T.clear(C_local)

            # Full K loop. NOTE(review): num_stages is fixed to 1 here and
            # does not use the `num_stages` argument — confirm intent.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A_buf[pid_m * block_M, k * block_K], A_shared)
                # NOTE(review): same [n, k] indexing of B as in the first wave.
                T.copy(B_buf[pid_n * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
            T.copy(C_local, C[pid_m * block_M, pid_n * block_N])

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, dtypeAB),
            B: T.Tensor(B_shape, dtypeAB),
            C: T.Tensor((M, N), dtypeC),
    ):
        # One program per entry of `streamk_programs`; `pid` is its index.
        with T.Kernel(streamk_programs, threads=threads) as pid:

            # Each phase gets its own pair of shared A/B tiles; the
            # accumulator fragment is shared by both phases.
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local)

            # Run the whole-tile phase only when it has at least one round.
            if sm_patition_factor > 0:
                compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)

    return main


def main():
    """Compile the kernel, run it on the module-level A and B, and compare with torch.

    Prints the generated kernel source and both result tensors, then raises
    (via ``torch.testing.assert_close``) if they differ beyond the tolerances.
    """
    # Positional order follows the signature of tl_matmul_streamk.
    kernel_args = (
        m,
        n,
        k,
        streamk_tiles,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        False,  # trans_A
        True,  # trans_B
        "float16",  # dtypeAB
        "float16",  # dtypeC
        "float32",  # accum_dtype
        2,  # num_stages
        64,  # threads
    )
    kernel = tl_matmul_streamk(*kernel_args)

    print(kernel.get_kernel_source())

    # The output starts at zero because the kernel accumulates partial tiles
    # into it with atomic adds.
    result = torch.zeros(m, n, device="cuda", dtype=torch.float16)
    kernel(A, B, result)

    reference = A @ B.T

    print(result)
    print(reference)
    torch.testing.assert_close(reference, result, rtol=1e-2, atol=1e-2)


# Build, run and check the kernel only when this file is executed as a script.
if __name__ == "__main__":
    main()