"configs/datasets/qaspercut/qaspercut_gen_a2d88a.py" did not exist on "c94cc943485e275897ad95cfa5192ff8e066378a"
benchmark_mha_sink_fwd.py 6.52 KB
Newer Older
1
2
3
4
5
6
7
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
8
from typing import Optional
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52


@triton.jit
def triton_kernel(
    Q,
    K,
    V,
    Sinks,
    sm_scale,
    Out,
    Z,
    H,
    N_Q_CTX,
    N_KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BANDWIDTH: tl.constexpr,
    start_q: tl.constexpr,
):
    """Attention forward pass with a per-head sink logit and optional window.

    Each program handles one BLOCK_M-row query tile (program_id(0)) of one
    (batch, head) pair (program_id(1)) and streams over key/value tiles of
    BLOCK_N rows using an online softmax (running max ``m_i``, running sum
    ``l_i``, running accumulator ``acc``).

    Args:
        Q, K, V: Tile accessors used via ``.load([batch, head, row, 0])``; the
            loaded tiles are reshaped to ``[BLOCK_M, HEAD_DIM]`` (Q) and
            ``[BLOCK_N, HEAD_DIM]`` (K, V).
        Sinks: Pointer indexed by head, or ``None``. When ``None`` the sink
            logit is 0, which still adds ``exp(0 - m_i)`` to the denominator.
        sm_scale: Factor multiplied into the ``q @ k`` scores.
        Out: Tile accessor written via ``.store`` with a
            ``[1, 1, BLOCK_M, HEAD_DIM]`` tile.
        Z: Not read by the kernel body.
        H: Divisor used to split program_id(1) into batch and head indices.
        N_Q_CTX, N_KV_CTX: Not read by the kernel body.
        HEAD_DIM, BLOCK_M, BLOCK_N: Compile-time tile sizes; BLOCK_N must not
            exceed HEAD_DIM (static assert below).
        BANDWIDTH: Sliding-window width; a falsy value disables the window.
        start_q: Offset added to a query row index before it is compared with
            key positions.
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    # load attention sinks
    if Sinks is not None:  # noqa: SIM108
        sink = tl.load(Sinks + off_h).to(tl.float32)
    else:
        sink = 0

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the running max (seeded with the sink logit), running sum and accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])

    # Key range visited by this query tile: everything up to the end of the
    # tile (shifted by start_q), optionally clipped from below by the window.
    # NOTE(review): with a window, `lo` is a multiple of BLOCK_N only if
    # start_q and BANDWIDTH are; the tl.multiple_of hint below assumes it - confirm.
    if BANDWIDTH:
        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
    else:
        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # True where the key position lies after the (offset) query position.
        mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]

        if BANDWIDTH:
            # Also mask keys that are BANDWIDTH or more positions behind the query.
            too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
            mask = mask | too_old

        k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
        qk = tl.dot(q, k, allow_tf32=False)

        # Masked entries get a large negative bias (-1.0e6) instead of -inf.
        qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]

        p = tl.math.exp(qk)
        # Rescale the previous partial results to the new running max.
        alpha = tl.math.exp(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        acc = acc * alpha[:, None]

        v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
        # v = v.to(tl.float32)
        p = p.to(v.dtype)  # We perform fp16 gemm to utilize tensor core
        acc = tl.dot(p, v, acc, allow_tf32=False)

        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # The sink contributes to the softmax denominator only, not to the output sum.
    sink = tl.math.exp(sink - m_i)
    z = l_i + sink
    acc = acc / z[:, None]
    # m_i += tl.math.log(l_i)
    # m_ptrs = M + off_hz * N_Q_CTX + offs_m
    # tl.store(m_ptrs, m_i)
    acc = acc.to(Out.dtype)[None, None, :, :]
    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


96
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
    """Run ``triton_kernel`` on the given inputs and return its output tensor.

    Args:
        Q: Query tensor whose shape unpacks as (batch, heads, seq_q, head_dim).
        K: Key tensor; ``K.shape[2]`` is taken as the key/value sequence length.
        V: Value tensor, tiled the same way as ``K``.
        Sinks: Passed through to the kernel unchanged.
        window_size: Sliding-window width forwarded as ``BANDWIDTH``; ``None``
            is falsy in the kernel and so disables the window.

    Returns:
        A new tensor created with ``torch.empty_like(Q)`` and written by the kernel.
    """
    bs, n_heads, seq_q, head_dim = Q.shape
    seq_kv = K.shape[2]
    BLOCK_M = 64
    BLOCK_N = 64

    o = torch.empty_like(Q)
    # One program per (query tile, batch * head) pair.
    grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
    triton_kernel[grid](
        TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
        TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
        TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
        Sinks,
        1.0 / head_dim**0.5,
        TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
        bs,
        n_heads,
        N_Q_CTX=seq_q,
        N_KV_CTX=seq_kv,
        HEAD_DIM=head_dim,
        BANDWIDTH=window_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        # Offset of the first query row relative to the key sequence.
        start_q=seq_kv - seq_q,
    )
    return o


124
125
126
127
128
129
130
131
132
133
def main(
    batch: int = 1,
    heads: int = 32,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 128,
    window_size: Optional[int] = None,
    dtype: str = "float16",
    tune: bool = False,
):
    """Check the TileLang kernel against the reference, then benchmark it and Triton.

    Args:
        batch: Batch size.
        heads: Number of attention heads.
        seq_q: Query sequence length.
        seq_kv: Key/value sequence length.
        dim: Head dimension.
        window_size: Sliding-window width, or ``None`` for full attention.
            Must not exceed ``seq_q`` (asserted).
        dtype: ``"float16"`` or ``"bfloat16"``; any other value raises ``KeyError``.
        tune: When true, build the kernel without an explicit configuration
            and print its ``latency`` and ``config`` attributes; the
            correctness check and the Triton comparison are skipped.
    """
    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
    if window_size is not None:
        print("Using sliding window attention.")
        assert window_size <= seq_q
        flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim  # just a rough estimation
    else:
        print("Using full attention.")
        # The 0.5 factor halves the dense matmul cost for the full-attention case.
        flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
    # Two matmuls per attention forward pass (q @ k and p @ v).
    total_flops = 2 * flops_per_matmul

    if tune:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
        print(f"Best latency: {kernel.latency}")
        print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
        print(f"Best config: {kernel.config}")
    else:
        block_M = 128
        block_N = 128
        num_stages = 2
        threads = 256
        print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")

        kernel = flashattn(
            batch,
            heads,
            seq_q,
            seq_kv,
            dim,
            window_size,
            block_M=block_M,
            block_N=block_N,
            num_stages=num_stages,
            threads=threads,
            dtype=dtype,
        )

        Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)

        # Only the TileLang kernel is validated against the reference here.
        torch.testing.assert_close(
            kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
        )
        print("All checks passed.✅")

        # The 1e-9 factor assumes do_bench reports milliseconds (as the "ms" label implies).
        latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
        print("Triton: {:.2f} ms".format(latency))
        print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
        print("Tilelang: {:.2f} ms".format(latency))
        print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    # Command-line entry point: parse the benchmark settings and forward them to main().
    # The defaults here (batch 8, sequence length 4096) are larger than main()'s own defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
    parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
    parser.add_argument("--tune", action="store_true", help="tune")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)