"""test_sparse_flash_attn.py: tests for sgl_kernel.sparse_flash_attn (sparse attention and vertical/slash index conversion)."""
import math
from typing import List, Optional, Tuple

import pytest
import torch
from einops import rearrange, repeat
7
8
9
10
11
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
    sparse_attn_func,
)
12
from test_flash_attention import construct_local_mask, is_fa3_supported
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173


def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Dense reference attention used to check the sparse attention kernels.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim); nheads is expected to be
            a multiple of nheads_k (each KV head serves a contiguous group of
            query heads)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q), True marks valid query positions
        key_padding_mask: (batch_size, seqlen_k), True marks valid key positions
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking (forces a right window of 0)
        window_size: (int, int), left and right window size
        softcap: if > 0, scores become softcap * tanh(scores / softcap)
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
        key_leftpad: forwarded to construct_local_mask when a window is active
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim), in q's original dtype
        lse: (batch_size, nheads, seqlen_q), log-sum-exp of the final scores
            (after softcap, masking and attn_bias); -inf for fully masked rows
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # GQA/MQA: repeat each KV head for its contiguous group of query heads.
    k = torch.repeat_interleave(k, q.shape[2] // k.shape[2], dim=2)
    v = torch.repeat_interleave(v, q.shape[2] // v.shape[2], dim=2)
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))

    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_((~key_padding_mask)[:, None, None, :], float("-inf"))
    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias

    # The LSE must describe the same scores the softmax below normalizes;
    # computing it before softcap/masking/bias gives a wrong reference.
    lse_ref = scores.logsumexp(dim=-1)

    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            (~query_padding_mask)[:, None, :, None], 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_((~query_padding_mask)[:, :, None, None], 0.0)

    return output.to(dtype=dtype_og), lse_ref


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Reference causal attention over a paged KV cache, one sequence at a time.

    Args:
        query: (sum(query_lens), num_heads, head_size); the queries of all
            sequences packed back to back in ``query_lens`` order.
        key_cache / value_cache: (num_blocks, block_size, num_kv_heads, head_size).
        query_lens / kv_lens: per-sequence query and key/value lengths.
        block_tables: (num_seqs, max_blocks_per_seq) block ids into the caches.
        scale: multiplier applied to the queries before the QK^T product.
        sliding_window: if set, each query only sees the ``sliding_window``
            keys ending at its causal position.
        soft_cap: if set, scores become ``soft_cap * tanh(scores / soft_cap)``.

    Returns:
        (sum(query_lens), num_heads, head_size) tensor in the value dtype.
    """
    _, block_size, num_kv_heads, head_size = key_cache.shape
    # Gather blocks with int64 ids that live on the cache's device.
    block_tables = block_tables.to(device=key_cache.device, dtype=torch.long)

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i, query_len in enumerate(query_lens):
        kv_len = kv_lens[i]
        # clone to avoid clobbering the query tensor
        q = query[start_idx : start_idx + query_len].clone()
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len]

        if q.shape[1] != k.shape[1]:
            # GQA/MQA: share each KV head across a contiguous group of query heads.
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Causal mask aligned to the end of the KV sequence: query j may attend
        # keys 0 .. j + kv_len - query_len. Built on attn's device so that
        # masked_fill_ works regardless of the global default device.
        empty_mask = torch.ones(query_len, kv_len, device=attn.device)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        outputs.append(torch.einsum("hqk,khd->qhd", attn, v))
        start_idx += query_len

    return torch.cat(outputs, dim=0)


174
175
176
177
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    """sparse_attn_func with a pattern covering every key must match dense attention.

    For every (batch, head, row block) the first NNZ_S key blocks are passed as
    block offsets (0, block_size_N, ...) and every remaining key position as an
    individual column index. Assuming each block offset denotes a
    block_size_N-wide key range, the sparse pattern therefore covers all keys
    exactly once, and ref_attn is an exact reference for both output and LSE.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    if NNZ_S * block_size_N > seqlen_k:
        # Report the combination as skipped instead of silently passing it.
        pytest.skip("NNZ_S key blocks do not fit into seqlen_k")

    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )

    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    NNZ_V = seqlen_k - NNZ_S * block_size_N
    grid = (batch_size, num_heads, NUM_ROWS)
    block_count = torch.full(grid, NNZ_S, dtype=torch.int32)
    column_count = torch.full(grid, NNZ_V, dtype=torch.int32)
    # Same offsets / columns for every (batch, head, row); materialize with
    # contiguous() so the kernel receives ordinary dense int32 tensors.
    block_offset = (
        (torch.arange(NNZ_S, dtype=torch.int32) * block_size_N)
        .expand(*grid, NNZ_S)
        .contiguous()
    )
    column_index = (
        torch.arange(NNZ_S * block_size_N, seqlen_k, dtype=torch.int32)
        .expand(*grid, NNZ_V)
        .contiguous()
    )

    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )

    ref_out, ref_lse = ref_attn(q, k, v)

    # assert_close already reports the greatest absolute/relative difference.
    torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2)
    torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2)


261
262
# Sparse attention index utilities.
# Original (per-head) vertical/slash index conversion.
263
264
265
266
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    """Check column_index from convert_vertical_slash_indexes on a tiny input.

    One batch, one head, q/kv length 4 and 2x2 blocks give two row blocks:
    row 0 covers query tokens 0-1 and row 1 covers query tokens 2-3.

    Only the first column_count[b, h, row] entries of each column_index row
    are compared. Slots past that count are presumably left unwritten by the
    kernel (NOTE(review): confirm the output is allocated uninitialized), so
    asserting on them would make the test depend on allocator state.
    """
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [[[1, 3]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [[[2]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_S]
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    _, _, column_count, column_index = convert_vertical_slash_indexes(
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal=causal,
    )

    # Recorded kernel output, laid out as [BATCH, N_HEADS, NUM_ROWS, NNZ_V].
    if causal:
        expected_column_index = torch.tensor(
            [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
        )
    else:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )

    assert column_index.shape == expected_column_index.shape
    # valid[b, h, row, j] is True for j < column_count[b, h, row]; assumes
    # column_count is [BATCH, N_HEADS, NUM_ROWS], as sparse_attn_func consumes it.
    valid = torch.arange(
        column_index.shape[-1], device=column_index.device
    ) < column_count.unsqueeze(-1)
    assert torch.equal(column_index[valid], expected_column_index[valid]), (
        f"column_count={column_count.tolist()}, "
        f"column_index={column_index.tolist()}"
    )


# mergehead
321
322
323
324
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    """Check column_index from convert_vertical_slash_indexes_mergehead on a tiny input.

    One batch, two heads, q/kv length 4 and 2x2 blocks give two row blocks.
    vertical_indices_count / slash_indices_count presumably give, per head,
    how many leading entries of that head's index list are used (head 0: two
    vertical, one slash; head 1: one vertical, two slash).

    Only column_index entries below column_count[b, h, row] are compared. The
    expectations previously hard-coded for head 1 contained values such as
    -1079459945. These cannot be column indices (context_size is 4) and look
    like uninitialized memory, so asserting on slots past column_count made the
    test depend on allocator state.
    """
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    _, _, column_count, column_index = convert_vertical_slash_indexes_mergehead(
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal=causal,
    )

    # Recorded kernel output, laid out as [BATCH, N_HEADS, NUM_ROWS, NNZ_V].
    # -1 marks slots expected to lie past column_count; they are never
    # compared unless the kernel reports them as valid.
    if causal:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[-1, -1], [-1, -1]]]],
            dtype=torch.int32,
            device="cuda",
        )
    else:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1], [2, -1]]]],
            dtype=torch.int32,
            device="cuda",
        )

    assert column_index.shape == expected_column_index.shape
    # valid[b, h, row, j] is True for j < column_count[b, h, row]; assumes
    # column_count is [BATCH, N_HEADS, NUM_ROWS], as sparse_attn_func consumes it.
    valid = torch.arange(
        column_index.shape[-1], device=column_index.device
    ) < column_count.unsqueeze(-1)
    assert torch.equal(column_index[valid], expected_column_index[valid]), (
        f"column_count={column_count.tolist()}, "
        f"column_index={column_index.tolist()}"
    )


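# Skipped because this test relies on FA2. If it is re-enabled,
# `sparse_attn_varlen_func` must be imported, and the call to `ref_paged_attn`
# (which returns a single tensor, not a 3-tuple) must be adjusted.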
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
#                                     [(1024, 1328), (1, 2048)],
#                                     [(1025, 1328), (2, 2048)],
#                                     [(1025, 2049), (2, 1281)],
#                                     ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
#         seq_lens,
#         head_size,
#         dtype,
# ) -> None:
#     torch.set_default_device("cuda")
#     torch.cuda.manual_seed_all(0)
#     block_size_M = 64
#     block_size_N = 64
#     num_seqs = len(seq_lens)
#     query_lens = [x[0] for x in seq_lens]
#     kv_lens = [x[1] for x in seq_lens]
#     num_heads = 1
#     query = torch.randn(sum(query_lens),
#                         num_heads,
#                         head_size,
#                         dtype=dtype)
#     key = torch.randn(sum(kv_lens),
#                     num_heads,
#                     head_size,
#                     dtype=dtype)
#     value = torch.randn_like(key)
#     cu_query_lens = torch.tensor([0] + query_lens,
#                                 dtype=torch.int32).cumsum(dim=0,
#                                                         dtype=torch.int32)
#     cu_kv_lens = torch.tensor([0] + kv_lens,
#                                 dtype=torch.int32).cumsum(dim=0,
#                                                         dtype=torch.int32)
#     max_query_len = max(query_lens)
#     max_kv_len = max(kv_lens)

#     NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
#     NNZ_S = 20
#     NNZ_V = 2048
#     batch_size = len(query_lens)

#     block_counts = []
#     column_counts = []
#     block_offsets = []
#     column_indices = []
#     for b in range(batch_size):
#         block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         columns = kv_lens[b] - NNZ_S * block_size_N
#         column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
#         column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
#     block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
#     column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
#     out, lse = sparse_attn_varlen_func(
#         query,
#         key,
#         value,
#         block_count,
#         block_offset,
#         column_count,
#         column_index,
#         cu_seqlens_q=cu_query_lens,
#         cu_seqlens_k=cu_kv_lens,
#         max_seqlen_q=max_query_len,
#         max_seqlen_k=max_kv_len,
#         return_softmax_lse=True,
#     )

#     max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
#     block_tables = torch.randint(0,
#                                  2048,
#                                  (len(query_lens), max_num_blocks_per_seq),
#                                  dtype=torch.int32)
#     scale = head_size**-0.5

#     ref_out, ref_lse, _ = ref_paged_attn(
#         query,
#         key,
#         value,
#         query_lens=query_lens,
#         kv_lens=kv_lens,
#         block_tables=block_tables,
#         scale=scale
#     )

#     torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(out - ref_out))}"
#     torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(lse - ref_lse))}"

if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the shell/CI.
    raise SystemExit(pytest.main([__file__]))