# block_sparse_attn_triton.py
# ruff: noqa: E712
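# Block-sparse causal attention (forward pass only) implemented with Triton, plus
# helpers that build block-level masks (top-k or threshold based) and two correctness
# tests against a dense PyTorch reference.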
import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
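    """Build a boolean block-level attention mask from block scores.

    `x` holds per-block scores of shape (bsz, num_head, downsample_len, downsample_len),
    where downsample_len = sequence length / block size. For each row the `topk`
    highest-scoring key blocks are marked True; with `use_dense_for_last_block` the
    last two rows are made fully dense. The mask is then lower-triangularized so that
    only causally valid blocks remain.
    """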
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
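    """Like get_sparse_attn_mask_from_topk, but keep block (i, j) whenever its score
    exceeds `threshold` instead of taking the top-k blocks per row."""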
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    k_block_col_idx,
    block_mask_ptr,
    k_ptrs,
    v_ptrs,
    offs_m,
    offs_n,
    stride_kt,
    stride_vt,
    stride_bmask_n,
    sm_scale,
    seqlen_k,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
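    # One step of the online-softmax (flash-attention style) update for a single key
    # block: if the block-mask bit for this key block is set, compute the qk tile,
    # update the running row-max m_i and denominator l_i, rescale the accumulator, and
    # add p @ v. Blocks whose mask bit is False only cost the mask load.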
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    if mask_val == True:
        start_n = k_block_col_idx * BLOCK_N
        # -- compute qk ----

        k = tl.load(k_ptrs + start_n * stride_kt)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        qk *= sm_scale

        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        if LAST_K_BLOCK:
            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        p = tl.exp(qk)
        l_ij = tl.sum(p, 1)
        alpha = tl.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vt)

        p = p.to(v.type.element_ty)

        acc += tl.dot(p, v)
        # update m_i and l_i
        m_i = m_ij
    return acc, l_i, m_i


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    block_mask_ptr,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kd,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vd,
    stride_bmz,
    stride_bmh,
    stride_bmm,
    stride_bmn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_od,
    H,
    N_CTX,
    PAST_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
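    # Grid layout: program_id(0) walks query blocks of BLOCK_M rows, program_id(1)
    # walks batch * heads. Each program loads its query tile once, then visits the key
    # blocks selected by its row of the block mask (row = query block index within q,
    # column = key block index) and accumulates the output in fp32.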
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    mask_ptrs = block_mask_ptr + start_m * stride_bmm

    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)

    k_block_start = 0
    # Include PAST_LEN so the loop covers every key block that is causally visible to
    # this query block (local query row i sits at global position i + PAST_LEN).
    k_block_end = tl.cdiv((start_m + 1) * BLOCK_M + PAST_LEN, BLOCK_N)

    # loop over k, v and update accumulator
    for col_idx in range(k_block_start, k_block_end):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            col_idx,
            mask_ptrs,
            k_ptrs,
            v_ptrs,
            offs_m,
            offs_n,
            stride_kn,
            stride_vn,
            stride_bmn,
            sm_scale,
            N_CTX,
            PAST_LEN,
            col_idx == k_block_end - 1,
            BLOCK_M,
            BLOCK_N,
        )
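    # Epilogue (below): m_i + log(l_i) is the row-wise log-sum-exp; it is computed but
    # not written out (it would only be needed by a backward pass). The accumulator is
    # normalized by the softmax denominator l_i and cast to the output dtype.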

    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)

    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)  # Out only has Q_LEN rows per (batch, head)


def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
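    """Launch the block-sparse attention forward kernel.

    Expected layouts (inferred from the kernel and the tests in this file):
      q:                 (batch, heads, q_len, head_dim) with head_dim in {64, 128}
      k, v:              (batch, heads, k_len, head_dim) with k_len >= q_len
      block_sparse_mask: boolean block mask; row i selects the key blocks visible to
                         query block i, column j addresses key block j.
    Returns a tensor shaped like q (written into `out` when it is provided).
    """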
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
    grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])

    assert q.shape[-1] in [64, 128]
    BLOCK_DMODEL = q.shape[-1]

    if num_warps is None:
        # Default launch config; callers may override via num_warps / num_stages.
        num_warps, num_stages = (8, 1) if is_hip() else (4, 2)

    N_CTX = k.shape[2]
    PAST_LEN = N_CTX - q.shape[2]

    H = q.shape[1]

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        block_sparse_mask,
        o,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *block_sparse_mask.stride(),
        *o.stride(),
        H,
        N_CTX,
        PAST_LEN,
        BLOCK_M,
        BLOCK_N,
        BLOCK_DMODEL,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o


class _sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
        # shape constraints
        return _forward(ctx, q, k, v, block_sparse_dense, sm_scale)

    @staticmethod
    def backward(ctx, do):
        # Backward is not implemented; this kernel is forward-only.
        raise NotImplementedError("block_sparse_triton_fn does not support gradient propagation yet")


block_sparse_triton_fn = _sparse_attention.apply
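# Minimal usage sketch (shapes inferred from the kernels and the tests below):
#   q:          (batch, heads, q_len, head_dim), head_dim in {64, 128}
#   k, v:       (batch, heads, k_len, head_dim), k_len >= q_len
#   block_mask: boolean block mask, e.g. the square (num_key_blocks x num_key_blocks)
#               mask produced by get_sparse_attn_mask_from_topk or
#               get_sparse_attn_mask_from_threshold
#   out = block_sparse_triton_fn(q, k, v, block_mask, 1.0 / head_dim**0.5)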


def test_topk_sparse_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
    TOPK = 2  # keep the top 2 key blocks per query-block row
    BLOCK = 64
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    # print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)

    # Run Triton kernel
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)

    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    # print("ref_output", ref_output)
    # print("triton_output", triton_output)

    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")


def test_topk_sparse_attention_qlt_kl():
    BATCH, N_HEADS = 2, 4
    Q_LEN, K_LEN, D_HEAD = 128, 256, 64  # qlen < klen; here, past_len = 256 - 128 = 128.
    TOPK = 1
    BLOCK = 64  # block size used in downsampling
    torch.manual_seed(0)

    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)

    downsample_factor = BLOCK
    print("downsample_factor", downsample_factor)
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    print("downsample_len", downsample_len)
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)
    # Run Triton kernel.
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)

    past_len = K_LEN - Q_LEN

    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale

    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]

    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)

    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)

    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"

    print("Pass topk sparse attention test with qlen < klen")


def main():
    test_topk_sparse_attention()
    test_topk_sparse_attention_qlt_kl()


if __name__ == "__main__":
    main()