# benchmark_mamba_chunk_scan.py 16.3 KB
# Newer Older
1
2
3
4
5
6
7
import argparse
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import math
from tilelang.profiler import do_bench

try:
    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
except ImportError as err:
    raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err

try:
    import helion
    from helion._testing import run_example
    import helion.language as hl
except ImportError as err:
    raise ImportError("Please install helion to use the helion chunk scan operator.") from err
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
    """
    Pure PyTorch/einops reference for the chunk-scan forward pass.

    Argument:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        C: (batch, seqlen, ngroups, dstate)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,); may be None, in which case the
            ``x * D`` term is skipped
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    _, _, ngroups, _, _ = cb.shape
    _, seqlen, nheads, _ = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size

    # Expand the group axis so every head has its own copy of C and cb.
    heads_per_group = nheads // ngroups
    C = repeat(C, "b l g d -> b l (g h) d", h=heads_per_group)
    cb = repeat(cb, "b c g l s -> b c (g h) l s", h=heads_per_group)

    # Pairwise differences of the cumulative sums inside each chunk:
    # (batch, nheads, nchunks, chunk_size, chunk_size)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    decay = torch.exp(dt_segment_sum)
    scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")

    # Keep only the lower triangle (row index >= column index).
    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)

    # Intra-chunk contribution.
    out = torch.einsum(
        "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
    )

    # Contribution of the state carried into each chunk, scaled by exp(dA_cumsum).
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = (
        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
    )

    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            # Per-head scalar: add a trailing axis so it broadcasts over headdim.
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
    """Call mamba_ssm's Triton ``_chunk_scan_fwd`` and return only the first of its two results."""
    result = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
    output, _ = result
    return output


def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
    """
    Build a Helion chunk-scan kernel and pass it to ``helion._testing.run_example``.

    The kernel is handed to ``run_example`` together with ``ref_program`` and the
    argument tuple ``(cb, x, dt, dA_cumsum, C, states, D)``. Nothing is returned;
    whatever ``run_example`` reports is the only output of this function.
    (Presumably it compares against the reference and prints timings — confirm
    against the Helion documentation.)

    Arguments have the shapes documented on the inner kernel; ``states`` is
    forwarded as the kernel's ``prev_states``.
    """

    @helion.kernel()
    def helion_mamba2_chunk_scan_kernel(
        cb: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        dA_cumsum: torch.Tensor,
        C: torch.Tensor,
        prev_states: torch.Tensor,
        D: torch.Tensor,
    ) -> torch.Tensor:
        """
        Argument:
            cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
            x: (batch, seqlen, nheads, headdim)
            dt: (batch, nheads, nchunks, chunk_size)
            dA_cumsum: (batch, nheads, nchunks, chunk_size)
            C: (batch, seqlen, ngroups, dstate)
            prev_states: (batch, nchunks, nheads, headdim, dstate)
            D: (nheads,)
        Return:
            out: (batch, seqlen, nheads, headdim)
        """

        batch, nchunks, ngroups, chunk_size, _ = cb.shape
        _, seqlen, nheads, headdim = x.shape
        _, _, _, dstate = C.shape
        assert nchunks == (seqlen + chunk_size - 1) // chunk_size

        block_m = hl.register_block_size(chunk_size)
        block_n = hl.register_block_size(headdim)
        block_k = hl.register_block_size(64, 64)
        dstate = hl.specialize(dstate)

        assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
        assert x.shape == (batch, seqlen, nheads, headdim)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        assert C.shape == (batch, seqlen, ngroups, dstate)
        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
        assert D.shape == (nheads,)

        dtype = cb.dtype
        accum_dtype = torch.float32
        # All inputs must share cb's dtype.
        assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype

        out = torch.empty_like(x)

        # log2(e): exp2(v * p) equals exp(v).
        p = 1.44269504

        # One tile per (head, chunk-row block, headdim block, batch, chunk);
        # head, batch and chunk tiles have size 1, so `.begin` is their index.
        for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
            [nheads, chunk_size, headdim, batch, nchunks],
            block_size=[1, block_m, block_n, 1, 1],
        ):
            acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
            scale_m_local = torch.exp2(dA_cumsum_local_m * p)

            # Contribution of the incoming state: (C @ prev_states^T) scaled per row.
            C_local = C[
                tile_b.begin,
                tile_m.index + tile_c.begin * chunk_size,
                tile_h.begin // (nheads // ngroups),
                :,
            ]
            prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :]
            acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o)
            acc_o *= scale_m_local[:, None]

            # Intra-chunk part: only columns up to the end of this row tile are visited.
            for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k):
                cb_local = cb[
                    tile_b.begin,
                    tile_c.begin,
                    tile_h.begin // (nheads // ngroups),
                    tile_m,
                    tile_k,
                ]
                dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
                cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
                dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
                cb_local = (cb_local * dt_local[None, :]).to(dtype)
                # Causal mask: zero entries whose column index exceeds the row index.
                pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
                cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local))
                x_local = x[
                    tile_b.begin,
                    tile_c.begin * chunk_size + tile_k.index,
                    tile_h.begin,
                    tile_n,
                ]
                acc_o = hl.dot(cb_local, x_local, acc=acc_o)

            # Add the per-head x * D term.
            D_local = D[tile_h.begin].to(torch.float32)
            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
            acc_o += x_residual * D_local
            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)

        return out

    args = (cb, x, dt, dA_cumsum, C, states, D)
    run_example(helion_mamba2_chunk_scan_kernel, ref_program, args)


178
def get_configs():
    """
    Enumerate the autotuning search space.

    Returns:
        A list of dicts, one per combination of ``block_M``, ``block_N``,
        ``block_K``, ``block_Dstate`` and ``num_stages`` (the Cartesian product
        of the candidate lists below, with the last key varying fastest).
    """
    iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[7],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def chunk_scan_fwd(
    batch,
    seqlen,
    chunk_size,
    ngroups,
    nheads,
    headdim,
    dstate,
    block_M=64,
    block_N=64,
    block_K=64,
    block_Dstate=128,
    num_stages=2,
    threads=128,
):
    """
    Build the TileLang chunk-scan forward kernel for one problem size.

    Args:
        batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate: problem
            dimensions; they fix the tensor shapes in the kernel signature.
        block_M: tile size along the chunk row axis.
        block_N: tile size along headdim.
        block_K: tile size along the chunk column (reduction) axis.
        block_Dstate: number of state columns loaded for the C @ prev_states
            product. Only ``0:block_Dstate`` is read, so a ``dstate`` larger
            than ``block_Dstate`` is not fully covered by this kernel.
        num_stages: stage count passed to ``T.Pipelined``.
        threads: thread count passed to ``T.Kernel``.

    Returns:
        The ``T.prim_func`` ``main``. Its parameters are, in order, cb, x, dt,
        dA_cumsum, C, prev_states, D, Output; ``out_idx=[7]`` in the
        ``tilelang.jit`` decorator refers to ``Output``.
    """
    dtype = "float16"
    accum_dtype = "float"
    nchunks = T.ceildiv(seqlen, chunk_size)
    # log2(e): exp2(v * p) equals exp(v).
    p = 1.44269504

    @T.prim_func
    def main(
        cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
        x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
        dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
        dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
        C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
        prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
        D: T.Tensor((nheads), dtype),  # type: ignore
        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
    ):
        # bz: head index; bx: flattened (row tile, headdim tile); by: flattened (chunk, batch).
        with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
            bz,
            bx,
            by,
        ):
            acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
            acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
            cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
            cb_local = T.alloc_fragment((block_M, block_K), dtype)
            dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
            dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
            dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
            dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
            dt_local = T.alloc_fragment((block_K), accum_dtype)
            x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
            dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
            scale_m_local = T.alloc_fragment((block_M), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
            prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
            D_local = T.alloc_fragment((1), accum_dtype)
            x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
            x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            batch_idx = by % batch
            chunk_idx = by // batch
            # m: chunk_size
            # n : headdim
            m_idx = bx // T.ceildiv(headdim, block_N)
            n_idx = bx % T.ceildiv(headdim, block_N)

            T.annotate_layout(
                {
                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
                }
            )

            T.no_set_max_nreg()

            # Per-row cumulative sums for this row tile, and their exponentials.
            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
            T.copy(dA_cs_m_shared, dA_cs_m_local)
            T.clear(acc_o)

            for i in T.Parallel(block_M):
                scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)

            # Incoming-state contribution: acc_o = (C @ prev_states^T) * scale_m.
            T.copy(
                C[
                    batch_idx,
                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
                    bz // (nheads // ngroups),
                    0:block_Dstate,
                ],
                C_shared,
            )
            T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
            T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
            for i, j in T.Parallel(block_M, block_N):
                acc_o[i, j] *= scale_m_local[i]

            # Only column blocks up to the end of this row tile are visited.
            loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    cb[
                        batch_idx,
                        chunk_idx,
                        bz // (nheads // ngroups),
                        m_idx * block_M : (m_idx + 1) * block_M,
                        k * block_K : (k + 1) * block_K,
                    ],
                    cb_shared,
                )
                T.copy(cb_shared, cb_local)
                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
                T.copy(dA_cs_k_shared, dA_cs_k_local)
                # Scale by exp(dA_cumsum[row] - dA_cumsum[col]).
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                T.copy(dt_shared, dt_local)
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] *= dt_local[j]
                # Causal mask: keep entries whose row index is >= the column index.
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
                T.copy(
                    x[
                        batch_idx,
                        chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
                        bz,
                        n_idx * block_N : (n_idx + 1) * block_N,
                    ],
                    x_shared,
                )
                T.gemm(cb_local, x_shared, acc_o)

            # Add the per-head x * D term.
            D_local[0] = D[bz]
            T.copy(
                x[
                    batch_idx,
                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
                    bz,
                    n_idx * block_N : (n_idx + 1) * block_N,
                ],
                x_residual_shared,
            )
            T.copy(x_residual_shared, x_residual_local)
            for i, j in T.Parallel(block_M, block_N):
                acc_o[i, j] += x_residual_local[i, j] * D_local[0]

            # Stage the accumulator through acc_o_shared before writing Output.
            T.copy(acc_o, acc_o_shared)
            T.copy(
                acc_o_shared,
                Output[
                    batch_idx,
                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
                    bz,
                    n_idx * block_N : (n_idx + 1) * block_N,
                ],
            )

    return main


if __name__ == "__main__":
    # Benchmark driver: autotune the TileLang kernel, then time the Triton
    # operator and run the Helion example on freshly generated random inputs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=80, help="heads")
    parser.add_argument("--groups", type=int, default=1, help="groups")
    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
    parser.add_argument("--dim", type=int, default=64, help="dim")
    parser.add_argument("--dstate", type=int, default=128, help="dstate")
    # NOTE: --tune is accepted but not consulted below; chunk_scan_fwd is
    # always wrapped by its autotune decorator.
    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()
    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
        args.batch,
        args.heads,
        args.groups,
        args.seq_len,
        args.chunk_size,
        args.dim,
        args.dstate,
    )
    nchunks = math.ceil(seq_len / chunk_size)
    # NOTE(review): the 0.5 factor presumably accounts for the causal mask in
    # the intra-chunk product — confirm.
    total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate

    print("Benchmarking TileLang...")
    kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
    best_latency = kernel.latency
    best_config = kernel.config
    print(f"Best latency: {best_latency}")
    # The 1e-9 factor yields TFLOP/s only if latency is in milliseconds — TODO confirm.
    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")

    cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda()
    x = torch.randn(batch, seq_len, heads, dim).half().cuda()
    dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
    dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
    C = torch.randn(batch, seq_len, groups, dstate).half().cuda()
    states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda()
    D = torch.randn(heads).half().cuda()

    print("Benchmarking Triton...")
    triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
    print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")

    print("Benchmarking Helion...")
    chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D)