# ruff: noqa: E712
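# Block-sparse causal attention forward pass in Triton, with helpers that build a
# block-level mask (top-k or threshold based) and self-contained correctness tests
# against a dense PyTorch reference.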
import math
import torch

import triton
import triton.language as tl
import torch.nn.functional as F


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    # Build a block-level boolean mask that keeps the top-k scoring key blocks per
    # query block, then applies the causal (lower-triangular) constraint.
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False,
                            dtype=torch.bool,
                            device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        # Optionally keep the last two query-block rows fully dense.
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    # Same idea as above, but keep every key block whose score exceeds the threshold.
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    k_block_col_idx,
    block_mask_ptr,
    k_ptrs,
    v_ptrs,
    offs_m,
    offs_n,
    stride_kt,
    stride_vt,
    stride_bmask_n,
    sm_scale,
    seqlen_k,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):

    # Load the block-mask entry for this (query block, key block) pair; the whole
    # tile of work is skipped when the block is masked out.
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)

    if mask_val == True:
        start_n = k_block_col_idx * BLOCK_N
        # -- compute qk ----

        k = tl.load(k_ptrs + start_n * stride_kt)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        qk *= sm_scale

        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        if LAST_K_BLOCK:
            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
                           float('-inf'))

        # Online (streaming) softmax update: rescale the running accumulator and
        # normalizer by exp(m_i - m_ij) whenever the running row max grows.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        p = tl.exp(qk)
        l_ij = tl.sum(p, 1)
        alpha = tl.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]

        # update acc
        v = tl.load(v_ptrs + start_n * stride_vt)

        p = p.to(v.type.element_ty)

        acc += tl.dot(p, v)
        # update m_i and l_i
        m_i = m_ij
    return acc, l_i, m_i


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    block_mask_ptr,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kd,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vd,
    stride_bmz,
    stride_bmh,
    stride_bmm,
    stride_bmn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_od,
    H,
    N_CTX,
    PAST_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
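    # One program handles BLOCK_M query rows for a single (batch, head) pair:
    # program_id(0) indexes the query block, program_id(1) indexes batch * head.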
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    mask_ptrs = block_mask_ptr + start_m * stride_bmm

    # Running row maximum (m_i), softmax normalizer (l_i), and output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)

    k_block_start = 0
    k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N)

    # loop over k, v and update accumulator
    for col_idx in range(k_block_start, k_block_end):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            col_idx,
            mask_ptrs,
            k_ptrs,
            v_ptrs,
            offs_m,
            offs_n,
            stride_kn,
            stride_vn,
            stride_bmn,
            sm_scale,
            N_CTX,
            PAST_LEN,
            col_idx == k_block_end - 1,
            BLOCK_M,
            BLOCK_N,
        )

    # m_i now holds the per-row log-sum-exp; it is not stored because no backward
    # pass is implemented. Normalize the accumulator by the softmax denominator.
    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)

    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
        None, :] * stride_od
    out_ptrs = Out + off_o
    # Mask against Q_LEN (not N_CTX): Out only has Q_LEN rows per (batch, head).
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)


def _forward(ctx,
             q,
             k,
             v,
             block_sparse_mask,
             sm_scale,
             BLOCK_M=64,
             BLOCK_N=64,
             num_warps=None,
             num_stages=1,
             out=None):

    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
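    # Launch one program per (query block, batch * head) pair.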
    grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])

    assert q.shape[-1] in [64, 128]
    BLOCK_DMODEL = q.shape[-1]

    if num_warps is None:
        # Default launch configuration; explicitly passed values are kept.
        if is_hip():
            num_warps, num_stages = 8, 1
        else:
            num_warps, num_stages = 4, 2

    N_CTX = k.shape[2]
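    # Number of past (cached) key/value tokens that precede the current queries.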
    PAST_LEN = N_CTX - q.shape[2]

    H = q.shape[1]

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        block_sparse_mask,
        o,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *block_sparse_mask.stride(),
        *o.stride(),
        H,
        N_CTX,
        PAST_LEN,
        BLOCK_M,
        BLOCK_N,
        BLOCK_DMODEL,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o


class _sparse_attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
        # shape constraints
        return _forward(ctx, q, k, v, block_sparse_dense, sm_scale)

    @staticmethod
    def backward(ctx, do):
        # No gradient propagation.
        raise NotImplementedError("Gradient propagation is not supported yet")


block_sparse_triton_fn = _sparse_attention.apply
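
# Minimal usage sketch (illustrative names only; assumes a CUDA device, fp16/bf16
# inputs, head_dim in {64, 128}, and the default 64-token block size):
#   scores = torch.randn(B, H, K_LEN // 64, K_LEN // 64, device='cuda')
#   block_mask = get_sparse_attn_mask_from_topk(scores, topk=2)
#   out = block_sparse_triton_fn(q, k, v, block_mask, 1.0 / q.shape[-1] ** 0.5)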


def test_topk_sparse_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
    TOPK = 2  # Keep the top-2 key blocks per query-block row
    BLOCK = 64
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)

    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device='cuda',
                       dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    # print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)

    # Run Triton kernel
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)

    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)

    # print("ref_output", ref_output)
    # print("triton_output", triton_output)

    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")


def test_topk_sparse_attention_qlt_kl():
    BATCH, N_HEADS = 2, 4
    Q_LEN, K_LEN, D_HEAD = 128, 256, 64  # qlen < klen; here, past_len = 256 - 128 = 128.
    TOPK = 1
    BLOCK = 64  # block size used in downsampling
    torch.manual_seed(0)
    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)

    downsample_factor = BLOCK
    print("downsample_factor", downsample_factor)
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    print("downsample_len", downsample_len)
    x_ds = torch.randn(
        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)
    # Run Triton kernel.
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
    past_len = K_LEN - Q_LEN

    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale

    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]

    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)

    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)

    attn = attn.masked_fill(~final_mask, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)

    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference when qlen < klen"

    print("Pass topk sparse attention test with qlen < klen")


def main():
    test_topk_sparse_attention()
    test_topk_sparse_attention_qlt_kl()


if __name__ == "__main__":
    main()