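# example_grouped_gemm_bwd.py
# Grouped GEMM example in TileLang: forward kernel, weight-gradient (dB) backward
# kernel, and a check against a PyTorch reference implementation.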
import torch
import math
import argparse
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
    """Build the TileLang forward kernel for a grouped GEMM.

    For every group ``g`` the kernel computes
    ``C[off_g : off_g + size_g] = A[off_g : off_g + size_g] @ B[g]``, where
    ``size_g = batch_sizes[g]`` and ``off_g = batch_offsets[g]``.

    Row tiles are scheduled over a padded row space in which group ``g`` starts
    at ``batch_padded_offsets[g]``. The callers in this file round those offsets
    to multiples of ``block_M``, so each row tile falls inside a single group.
    The grid reserves ``batch_count`` extra row tiles to cover that padding.

    args:
        batch_sum (int): total number of rows of ``A`` and ``C``.
        batch_count (int): number of groups.
        K (int): reduction dimension.
        N (int): output width.
        block_M, block_N, block_K (int): tile sizes along rows, N and K.
        num_stages (int): pipelining depth of the K loop.
        threads (int): threads per block.
        dtype (str): element type of A, B and C; accumulation is float32.

    Kernel tensors (C is allocated and returned by the JIT wrapper, out_idx=[2]):
        A: (batch_sum, K), B: (batch_count, K, N), C: (batch_sum, N),
        batch_sizes / batch_offsets / batch_padded_offsets: int32, (batch_count,).
    """
    accum_dtype = "float32"

    @T.prim_func
    def kernel(
        A: T.Tensor([batch_sum, K], dtype),  # type: ignore
        B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
        C: T.Tensor([batch_sum, N], dtype),  # type: ignore
        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
        batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared([block_M, block_K], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
            cur_batch_idx = T.alloc_local([1], "int32")
            cur_batch_size = T.alloc_local([1], "int32")

            # First row of this tile in the padded row space.
            m_start_padded = bx * block_M

            # Select the last group whose padded start is <= m_start_padded
            # (assumes non-decreasing padded offsets, as the callers build them).
            # Initialise explicitly rather than relying on batch_padded_offsets[0] == 0.
            cur_batch_idx[0] = 0
            for i in range(batch_count):
                in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
                cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])

            cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
            # Map the padded tile start back to a real row of A/C.
            m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
            # Rows of this tile that belong to the group (0 for padding-only tiles).
            actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
                T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Only rows owned by the group are stored, so tiles never write into a neighbouring group.
            for i, j in T.Parallel(block_M, block_N):
                with T.If(i < actual_rows), T.Then():
                    C[m_start + i, by * block_N + j] = C_local[i, j]

    return kernel


class _GroupedGEMM(torch.autograd.Function):
    """Autograd wrapper around the TileLang grouped GEMM kernels.

    Forward computes, for every group ``g``,
    ``out[off_g : off_g + size_g] = a[off_g : off_g + size_g] @ b[g]``, where
    ``size_g = batch_sizes[g]`` and ``off_g`` is the running sum of the
    preceding sizes. Backward computes ``dB`` with ``grouped_gemm_bwd`` and,
    when ``a`` requires a gradient, ``dA`` with plain PyTorch matmuls.
    """

    @staticmethod
    def forward(ctx, a, b, batch_sizes):
        """Run the grouped forward GEMM.

        Args:
            a: Tensor of shape ``(batch_sum, K)``; rows of each group are contiguous.
            b: Tensor of shape ``(batch_count, K, N)``; one matrix per group.
            batch_sizes: 1-D integer tensor of length ``batch_count`` summing to ``batch_sum``.

        Returns:
            Tensor of shape ``(batch_sum, N)``.

        NOTE(review): the kernel is built with its default dtype ("float16"),
        so ``a`` and ``b`` are presumably expected to be float16.
        """
        block_M = 64
        block_N = 64
        block_K = 64
        padding_M = block_M
        num_stages = 2
        threads = 128
        batch_sum = a.shape[0]
        batch_count = b.shape[0]
        K = a.shape[1]
        N = b.shape[2]

        assert a.shape[1] == b.shape[1]
        assert batch_sizes.shape[0] == batch_count

        # Copy sizes to the host once instead of indexing the (possibly CUDA)
        # tensor element by element, which would sync per access.
        sizes = batch_sizes.tolist()
        assert sum(sizes) == batch_sum

        batch_offsets_list = [0]
        batch_padded_offsets_list = [0]
        for size in sizes[:-1]:
            batch_offsets_list.append(batch_offsets_list[-1] + size)
            # Integer form of math.ceil((size + 1) / padding_M) * padding_M.
            batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + (size + padding_M) // padding_M * padding_M)

        # The kernel indexes dense row-major int32 / dtype buffers.
        a = a.contiguous()
        b = b.contiguous()
        batch_sizes = batch_sizes.to(device=a.device, dtype=torch.int32).contiguous()
        batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32)
        batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32)

        kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads)
        o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)

        ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
        ctx.batch_sum = batch_sum
        ctx.batch_count = batch_count
        ctx.K = K
        # Host copies of the group layout, used for the dA computation in backward.
        ctx.batch_sizes_list = sizes
        ctx.batch_offsets_list = batch_offsets_list
        return o

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(dA, dB, None)`` for ``(a, b, batch_sizes)``.

        ``dB[g] = a_g^T @ grad_output_g`` comes from the TileLang kernel, and
        ``dA_g = grad_output_g @ b[g]^T`` comes from torch matmuls. Each is
        computed only if the corresponding input requires a gradient.
        """
        block_M = 64
        block_N = 64
        block_K = 64
        num_stages = 2
        threads = 128

        # In the dB kernel, "M" is the forward reduction dim K and "N" the output width.
        M = ctx.K
        N = grad_output.shape[1]

        A, B, batch_sizes, batch_offsets = ctx.saved_tensors

        # Autograd may pass a non-contiguous grad_output (e.g. the stride-0
        # expansion produced by out.sum().backward()); the kernel assumes
        # dense row-major operands, so normalise every operand it reads.
        A = A.contiguous()
        grad_output = grad_output.contiguous()

        dA = None
        if ctx.needs_input_grad[0]:
            dA = torch.empty_like(A)
            for g, (start, size) in enumerate(zip(ctx.batch_offsets_list, ctx.batch_sizes_list)):
                end = start + size
                dA[start:end] = torch.mm(grad_output[start:end], B[g].transpose(0, 1))

        dB = None
        if ctx.needs_input_grad[1]:
            kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads)
            dB = kernel(A, grad_output, batch_sizes, batch_offsets)

        return dA, dB, None


def ref_program(a, b, batch_sizes):
    """PyTorch reference for the grouped GEMM.

    The rows of ``a`` are split into consecutive groups of sizes
    ``batch_sizes``, and group ``i`` is multiplied by ``b[i]``.

    Args:
        a: Tensor of shape ``(sum(batch_sizes), K)``.
        b: Tensor of shape ``(len(batch_sizes), K, N)``.
        batch_sizes: Sequence of ints or 1-D integer tensor of group sizes.

    Returns:
        Tensor of shape ``(sum(batch_sizes), N)`` with ``a``'s dtype and device.
        Gradients flow back to ``a`` and ``b`` through the slice assignments.
    """
    # Convert once to plain ints so slicing never goes through 0-d (possibly CUDA) tensors.
    sizes = [int(s) for s in batch_sizes]
    assert a.shape[0] == sum(sizes)
    assert b.shape[0] == len(sizes)

    output = torch.empty((sum(sizes), b.shape[2]), device=a.device, dtype=a.dtype)

    start = 0
    for i, size in enumerate(sizes):
        end = start + size
        output[start:end] = torch.mm(a[start:end], b[i])
        start = end

    return output


def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
    """Build random inputs and group metadata for the grouped GEMM.

    Args:
        batch_sizes_list: List of group sizes (Python ints).
        K: Reduction dimension.
        M: Output width (columns of each ``B[g]``).
        trans_b: Currently unused; kept for call compatibility.
        padding_M: Alignment of the padded per-group row offsets.
        device: Device for all returned tensors.
        dtype: Element type of ``A``, ``B`` and ``C``.

    Returns:
        ``(A, B, C, batch_sizes, batch_offsets, batch_padded_offsets)``:
        ``A`` is random ``(batch_sum, K)``, ``B`` is random ``(batch_count, K, M)``,
        ``C`` is an uninitialised ``(batch_sum, M)`` buffer, and the three
        metadata tensors are int32 of length ``batch_count``.
    """
    batch_sum = sum(batch_sizes_list)
    batch_count = len(batch_sizes_list)

    batch_offsets_list = [0]
    batch_padded_offsets_list = [0]
    for size in batch_sizes_list[:-1]:
        batch_offsets_list.append(batch_offsets_list[-1] + size)
        # Round size + 1 up to a multiple of padding_M; integer form of
        # math.ceil((size + 1) / padding_M) * padding_M (no float rounding).
        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + (size + padding_M) // padding_M * padding_M)

    A = torch.randn(batch_sum, K, device=device, dtype=dtype)
    B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
    C = torch.empty(batch_sum, M, device=device, dtype=dtype)
    batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32)
    batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32)
    batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32)
    return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets


@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
    """Build the TileLang kernel for the weight gradient of a grouped GEMM.

    For every group ``g`` the kernel computes
    ``C[g] = A[off_g : off_g + size_g]^T @ B[off_g : off_g + size_g]``, where
    ``size_g = batch_sizes[g]`` and ``off_g = batch_offsets[g]``. In this
    file's backward pass, ``A`` is the forward input and ``B`` the output gradient.

    args:
        batch_sum (int): total number of rows of ``A`` and ``B``.
        batch_count (int): number of groups.
        M (int): columns of ``A`` (rows of each ``C[g]``).
        N (int): columns of ``B`` (columns of each ``C[g]``).
        block_M, block_N (int): tile sizes of ``C[g]``.
        block_K (int): number of group rows reduced per pipeline step.
        num_stages (int): pipelining depth of the reduction loop.
        threads (int): threads per block.
        dtype (str): element type of A, B and C; accumulation is float32.

    Kernel tensors (C is allocated and returned by the JIT wrapper, out_idx=[2]):
        A: (batch_sum, M), B: (batch_sum, N), C: (batch_count, M, N),
        batch_sizes / batch_offsets: int32, (batch_count,).
    """
    accum_dtype = "float32"

    @T.prim_func
    def kernel(
        A: T.Tensor([batch_sum, M], dtype),  # type: ignore
        B: T.Tensor([batch_sum, N], dtype),  # type: ignore
        C: T.Tensor([batch_count, M, N], dtype),  # type: ignore
        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
        # Grid: (M tiles, N tiles, group); block (bx, by, bz) owns one tile of C[bz].
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
            A_shared = T.alloc_shared([block_K, block_M], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)

            T.clear(C_local)
            # Reduce over the rows of group bz, block_K rows per step.
            for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages):
                # Zero-fill rows at or past the end of the group. The bound must use the
                # group-relative row k * block_K + i, not the tile-local i. Otherwise the
                # tail tile of a group whose size is not a multiple of block_K reads rows
                # of the next group (or past the end of A/B for the last group).
                for i, j in T.Parallel(block_K, block_M):
                    A_shared[i, j] = T.if_then_else(
                        k * block_K + i < batch_sizes[bz],
                        A[batch_offsets[bz] + k * block_K + i, bx * block_M + j],
                        0,
                    )
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = T.if_then_else(
                        k * block_K + i < batch_sizes[bz],
                        B[batch_offsets[bz] + k * block_K + i, by * block_N + j],
                        0,
                    )
                # C_local += A_shared^T @ B_shared
                T.gemm(A_shared, B_shared, C_local, transpose_A=True)

            T.copy(C_local, C[bz, bx * block_M, by * block_N])

    return kernel


def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
    """Check the TileLang grouped GEMM (output and dB) against the PyTorch reference.

    Builds random float16 CUDA inputs and runs ``ref_program`` and
    ``_GroupedGEMM`` forward and backward with the same upstream gradient.
    Prints whether both the outputs and the gradients w.r.t. ``B`` agree within
    ``rtol = atol = 1e-2``.

    NOTE: ``_GroupedGEMM`` uses its own hard-coded tile configuration, so
    ``block_N``, ``block_K``, ``num_stages`` and ``threads`` do not affect the
    kernels run here. ``block_M`` only sets the padding passed to
    ``construct_inputs``, and ``profile`` is currently unused.
    """
    padding_M = block_M
    device = torch.device("cuda")
    dtype = torch.float16

    A, B, _, batch_sizes, _, _ = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype)

    A.requires_grad_(False)
    B.requires_grad_(True)

    # Reference forward/backward; each graph is backpropagated once, so no retain_graph.
    O_ref = ref_program(A, B, batch_sizes)
    dO = torch.randn_like(O_ref)
    O_ref.backward(dO)
    dB_ref, B.grad = B.grad.clone(), None

    O = _GroupedGEMM.apply(A, B, batch_sizes)
    O.backward(dO)
    dB, B.grad = B.grad.clone(), None

    outputs_match = torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
    grads_match = torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)
    if outputs_match and grads_match:
        print("✅ Tilelang and Torch match")
    else:
        print("❌ Tilelang and Torch mismatch")


if __name__ == "__main__":
    # Command-line entry point: parse the problem shape and run the correctness check.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
    parser.add_argument("--K", type=int, default=8192, help="reduce dim")
    parser.add_argument("--M", type=int, default=8192, help="output dim")
    parser.add_argument("--trans_b", action="store_true", help="transpose B")
    parser.add_argument("--profile", action="store_true", help="profile")
    args = parser.parse_args()

    # int() tolerates the surrounding spaces in inputs such as "64, 128".
    batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
    K, M, trans_b = args.K, args.M, args.trans_b

    # Tile configuration forwarded to run_tilelang_grouped_gemm.
    block_M = 64
    block_N = 128
    block_K = 64
    num_stages = 2
    threads = 256

    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)