example_mamba_chunk_scan.py 12.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import torch
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools


def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
    out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
    return out


def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
    """
    Argument:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        C: (batch, seqlen, ngroups, dstate)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    _, _, ngroups, _, _ = cb.shape
    batch, seqlen, nheads, headdim = x.shape
    # _, _, ngroups, dstate = B.shape
    # assert B.shape == (batch, seqlen, ngroups, dstate)
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    # assert C.shape == B.shape
    # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
    cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
    # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
    #                   rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
    # (batch, nheads, nchunks, chunksize, chunksize)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    decay = torch.exp(dt_segment_sum)
    scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
    causal_mask = torch.tril(
        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
        C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out


def get_configs():
    block_M = [64, 128, 256]
    block_N = [32, 64]
    block_K = [64, 128, 256]
    block_Dstate = [128]
    num_stages = [1, 2, 3, 4, 5]
    _configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'block_K': c[2],
        'block_Dstate': c[3],
        'num_stages': c[4],
        'threads': c[0] * 2
    } for c in _configs]
    return configs


def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
    dtype = "float16"
    accum_dtype = "float"
    nchunks = T.ceildiv(seqlen, chunk_size)
    p = 1.44269504

    def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):

        @T.prim_func
        def main(cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
                 x: T.Buffer((batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
                     (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
                         (batch, nheads, nchunks, chunk_size), dtype), C: T.Buffer(
                             (batch, seqlen, ngroups, dstate), dtype), prev_states: T.Buffer(
                                 (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Buffer(
                                     (nheads), dtype), Output: T.Buffer(
                                         (batch, seqlen, nheads, headdim), dtype)):
            with T.Kernel(
                    nheads,
                    T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
                    batch * nchunks,
                    threads=threads) as (bz, bx, by):
                acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
                acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
                cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
                cb_local = T.alloc_fragment((block_M, block_K), dtype)
                dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
                dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
                dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
                dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
                dt_local = T.alloc_fragment((block_K), accum_dtype)
                x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
                dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
                scale_m_local = T.alloc_fragment((block_M), accum_dtype)
                C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
                prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
                D_local = T.alloc_fragment((1), accum_dtype)
                x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
                x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                batch_idx = by % batch
                chunk_idx = by // batch
                # m: chunk_size
                # n : headdim
                m_idx = bx // T.ceildiv(headdim, block_N)
                n_idx = bx % T.ceildiv(headdim, block_N)

                T.annotate_layout({
                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
                })

                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
                       dA_cs_m_shared)
                T.copy(dA_cs_m_shared, dA_cs_m_local)
                T.clear(acc_o)

                for i in T.Parallel(block_M):
                    scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
                T.copy(
                    C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
                      (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
                T.copy(
                    prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
                                0:block_Dstate], prev_state_shared)
                T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
                for i, j in T.Parallel(block_M, block_N):
                    acc_o[i, j] *= scale_m_local[i]

                loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(
                        cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
                           m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
                        cb_shared)
                    T.copy(cb_shared, cb_local)
                    T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
                           dA_cs_k_shared)
                    T.copy(dA_cs_k_shared, dA_cs_k_local)
                    for i, j in T.Parallel(block_M, block_K):
                        cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
                                                                 dA_cs_k_local[j] * p)
                    T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
                    T.copy(dt_shared, dt_local)
                    for i, j in T.Parallel(block_M, block_K):
                        cb_local[i, j] *= dt_local[j]
                    for i, j in T.Parallel(block_M, block_K):
                        cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
                                                        cb_local[i, j], 0)
                    T.copy(
                        x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
                          (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
                    T.gemm(cb_local, x_shared, acc_o)

                D_local[0] = D[bz]
                T.copy(
                    x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
                      (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
                    x_residual_shared)
                T.copy(x_residual_shared, x_residual_local)
                for i, j in T.Parallel(block_M, block_N):
                    acc_o[i, j] += x_residual_local[i, j] * D_local[0]

                T.copy(acc_o, acc_o_shared)
                T.copy(
                    acc_o_shared,
                    Output[batch_idx, chunk_idx * chunk_size +
                           m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz,
                           n_idx * block_N:(n_idx + 1) * block_N])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[7],
            supply_type=tilelang.TensorSupplyType.Normal,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None,
                   block_N=None,
                   block_K=None,
                   block_Dstate=None,
                   num_stages=None,
                   threads=None):
            return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
            return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)

        return kernel


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=80, help='heads')
    parser.add_argument('--groups', type=int, default=1, help='groups')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
    parser.add_argument('--dim', type=int, default=64, help='dim')
    parser.add_argument('--dstate', type=int, default=128, help='dstate')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
    total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate

    if (not args.tune):
        program = chunk_scan_fwd(
            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
                block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
        mod, params = tilelang.lower(program)
        mod = Profiler(mod, params, [7], tilelang.TensorSupplyType.Normal)
        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = mod.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = mod.do_bench(mod.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = chunk_scan_fwd(
            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")