# test_triton_attention_kernels.py
1
2
3
4
import random
import unittest

import torch
5
import torch.nn.functional as F
6

7
8
9
10
11
from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_fwd_grouped,
    decode_attention_fwd_normal,
)
12
from sglang.srt.layers.attention.triton_ops.extend_attention import (
13
    build_unified_kv_indices,
14
    extend_attention_fwd,
15
    extend_attention_fwd_unified,
16
17
    redundant_attention,
)
18
19
20
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)
21
from sglang.test.test_utils import CustomTestCase
22
23


24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def extend_attention_fwd_torch(
    q: torch.Tensor,  # [extend_tokens, H_Q, D]
    k: torch.Tensor,  # [extend_tokens, H_KV, D]
    v: torch.Tensor,  # [extend_tokens, H_KV, D]
    o: torch.Tensor,  # [extend_tokens, H_Q, D]
    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    qo_indptr: torch.Tensor,  # [B+1]
    kv_indptr: torch.Tensor,  # [B+1]
    kv_indices: torch.Tensor,  # [prefix_tokens]
    sliding_window_size: int,
):
    """Pure-PyTorch reference for causal extend attention with an optional window.

    For request ``r`` the queries ``q[qo_indptr[r]:qo_indptr[r+1]]`` attend to
    the request's cached prefix ``k_cache/v_cache[kv_indices[kv_indptr[r]:
    kv_indptr[r+1]]]`` followed by its own new keys/values
    ``k/v[qo_indptr[r]:qo_indptr[r+1]]``.

    A query at absolute position ``t`` (prefix length + offset inside the extend
    chunk) may see key position ``p`` iff ``p <= t`` and, when
    ``sliding_window_size`` is a positive int, ``p >= t - sliding_window_size``.
    ``None`` or a non-positive window disables windowing. Every KV head is shared
    by ``H_Q // H_KV`` consecutive query heads.

    The result is written in place into ``o``; nothing is returned.
    """
    num_reqs = qo_indptr.size(0) - 1
    _, num_q_heads, head_dim = q.shape
    _, num_kv_heads, _ = k.shape

    heads_per_kv = num_q_heads // num_kv_heads
    softmax_scale = 1.0 / head_dim**0.5
    use_window = sliding_window_size is not None and sliding_window_size > 0

    for req in range(num_reqs):
        q_lo, q_hi = int(qo_indptr[req].item()), int(qo_indptr[req + 1].item())
        kv_lo, kv_hi = int(kv_indptr[req].item()), int(kv_indptr[req + 1].item())

        slots = kv_indices[kv_lo:kv_hi]
        k_new = k[q_lo:q_hi]

        # Prefix tokens come first, then the freshly appended ones.
        keys = torch.cat([k_cache[slots], k_new], dim=0)
        values = torch.cat([v_cache[slots], v[q_lo:q_hi]], dim=0)
        if heads_per_kv != 1:
            # Expand KV heads so they line up 1:1 with query heads (GQA/MQA).
            keys = keys.repeat_interleave(heads_per_kv, dim=1)
            values = values.repeat_interleave(heads_per_kv, dim=1)

        n_prefix = keys.size(0) - k_new.size(0)
        n_new = k_new.size(0)

        key_pos = torch.arange(n_prefix + n_new, device=q.device)
        query_pos = n_prefix + torch.arange(n_new, device=q.device)
        if use_window:
            lowest_visible = (query_pos - sliding_window_size).clamp_min(0)
        else:
            lowest_visible = torch.zeros_like(query_pos)

        visible = (key_pos[None, :] <= query_pos[:, None]) & (
            key_pos[None, :] >= lowest_visible[:, None]
        )  # [extend_len, total_len]

        logits = torch.einsum("qhd,khd->qhk", q[q_lo:q_hi], keys) * softmax_scale
        logits = logits.masked_fill(~visible[:, None, :], float("-inf"))
        probs = F.softmax(logits, dim=-1)
        o[q_lo:q_hi] = torch.einsum("qhk,khd->qhd", probs, values)


98
class TestTritonAttention(CustomTestCase):
    """GPU correctness tests for the Triton attention kernels.

    Covers extend attention (2-stage causal, explicit custom mask, sliding
    window, and the unified single-stage kernel), prefill/context attention,
    and decode attention (normal and grouped). All tensors live on CUDA.
    """

    def _set_all_seeds(self, seed):
        """Seed Python and torch (CPU and every CUDA device) RNGs and make
        cuDNN deterministic and non-benchmarking."""
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def setUp(self):
        # Re-seed before every test method so each test is reproducible alone.
        self._set_all_seeds(42)

    # ------------------------------------------------------------------
    # Extend attention
    # ------------------------------------------------------------------

    def _build_extend_inputs(
        self, B, N_CTX, H_Q, H_KV, D, kv_indices_dtype=torch.int32
    ):
        """Build a random bf16 extend batch shared by the extend-attention tests.

        Each of the ``B`` requests gets a random prefix length and a random
        extend length, both in ``[1, N_CTX // 2)``. All tokens of request ``i``
        are stored contiguously in ``k_buffer``/``v_buffer`` starting at
        ``b_start_loc[i]``: first its prefix, then its extend tokens. The extend
        tokens are also copied, packed request by request, into
        ``k_extend``/``v_extend``.

        Returns a ``SimpleNamespace`` with the per-request lengths and offsets,
        the CSR prefix layout (``kv_indptr``/``kv_indices``), the CSR extend
        layout (``qo_indptr``), the KV buffers, the extend q/k/v tensors, and
        ``max_len_extend``.
        """
        import types

        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len_extend = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len = b_seq_len_prefix + b_seq_len_extend

        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        # Prefix slots of request i: kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
        kv_indices = torch.zeros(
            (b_seq_len_prefix.sum().item(),), dtype=kv_indices_dtype, device="cuda"
        )
        for i in range(B):
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()
        k_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)
        v_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.empty(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
            ).normal_(mean=0.1, std=0.2)

        # Extend tokens of request i: rows qo_indptr[i] : qo_indptr[i + 1].
        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        return types.SimpleNamespace(
            b_seq_len_prefix=b_seq_len_prefix,
            b_seq_len_extend=b_seq_len_extend,
            b_seq_len=b_seq_len,
            b_start_loc=b_start_loc,
            b_start_loc_extend=b_start_loc_extend,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            qo_indptr=qo_indptr,
            k_buffer=k_buffer,
            v_buffer=v_buffer,
            q_extend=q_extend,
            k_extend=k_extend,
            v_extend=v_extend,
            max_len_extend=max_len_extend,
        )

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
        """Check ``extend_attention_fwd`` against ``redundant_attention``.

        The kernel runs twice: once with ``is_causal`` and no mask, and once
        with an explicit custom mask encoding the same causal pattern. Both
        outputs must match the redundant reference.
        """
        inp = self._build_extend_inputs(B, N_CTX, H_Q, H_KV, D)

        # 1) Plain causal extend attention.
        o_extend = torch.empty_like(inp.q_extend)
        extend_attention_fwd(
            inp.q_extend,
            inp.k_extend,
            inp.v_extend,
            o_extend,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            inp.kv_indptr,
            inp.kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=inp.max_len_extend,
        )

        # 2) The same causal pattern, expressed as an explicit flattened mask.
        # For request i this is a row-major [extend_len_i, seq_len_i] block in
        # which every query sees the whole prefix plus the extend tokens up to
        # and including itself.
        b_seq_mask_len = inp.b_seq_len_extend * inp.b_seq_len
        custom_mask = torch.ones(
            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
        )
        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
        mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
        for i in range(B):
            n_extend = int(inp.b_seq_len_extend[i])
            n_prefix = int(inp.b_seq_len_prefix[i])
            prefix_mask = torch.ones(n_extend, n_prefix, dtype=torch.bool)
            causal_mask = torch.ones(n_extend, n_extend, dtype=torch.bool).tril()
            mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
            custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten

        o_extend_mask = torch.empty_like(inp.q_extend)
        extend_attention_fwd(
            inp.q_extend,
            inp.k_extend,
            inp.v_extend,
            o_extend_mask,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            inp.kv_indptr,
            inp.kv_indices,
            custom_mask=custom_mask,
            is_causal=True,
            mask_indptr=mask_indptr,
            max_len_extend=inp.max_len_extend,
        )

        # 3) Reference.
        b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
        max_len_in_batch = torch.max(inp.b_seq_len, 0)[0].item()
        o_redundant = torch.empty_like(inp.q_extend)
        redundant_attention(
            inp.q_extend,
            o_redundant,
            inp.k_buffer,
            inp.v_buffer,
            b_req_idx,
            inp.b_start_loc,
            inp.b_seq_len,
            inp.b_seq_len_prefix,
            max_len_in_batch,
        )

        self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
        self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))

    def test_extend_attention(self):
        # Head dims include non-power-of-two (96, 80) and prime (13) sizes.
        for head_dim in [128, 96, 80, 13]:
            with self.subTest(D=head_dim):
                self._test_extend_attention_once(19, 12331, 12, 4, head_dim)

    def _test_extend_attention_sliding_window_once(
        self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
    ):
        """Compare the Triton extend kernel with ``extend_attention_fwd_torch``
        for a given sliding window (``-1`` disables the window)."""
        inp = self._build_extend_inputs(B, N_CTX, H_Q, H_KV, D)

        o_extend_triton = torch.empty_like(inp.q_extend)
        o_extend_torch = torch.empty_like(inp.q_extend)

        extend_attention_fwd(
            inp.q_extend,
            inp.k_extend,
            inp.v_extend,
            o_extend_triton,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            inp.kv_indptr,
            inp.kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=inp.max_len_extend,
            sliding_window_size=WINDOW_SIZE,
        )

        extend_attention_fwd_torch(
            inp.q_extend,
            inp.k_extend,
            inp.v_extend,
            o_extend_torch,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            inp.kv_indptr,
            inp.kv_indices,
            WINDOW_SIZE,
        )

        self.assertTrue(
            torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
        )

    def test_extend_attention_sliding_window(self):
        for window_size in [-1, 127]:
            with self.subTest(window_size=window_size):
                self._test_extend_attention_sliding_window_once(
                    19, 12331, 64, 8, 128, window_size
                )

    def _test_extend_attention_unified_vs_regular_once(self, B, N_CTX, H_Q, H_KV, D):
        """Test that unified kernel produces same results as 2-stage kernel."""
        inp = self._build_extend_inputs(
            B, N_CTX, H_Q, H_KV, D, kv_indices_dtype=torch.int64
        )

        # Run 2-stage kernel
        o_regular = torch.empty_like(inp.q_extend)
        extend_attention_fwd(
            inp.q_extend,
            inp.k_extend,
            inp.v_extend,
            o_regular,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            inp.kv_indptr,
            inp.kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=inp.max_len_extend,
        )

        # Buffer slots of the extend tokens, packed in extend order. Request i's
        # extend tokens live at [b_start_loc[i] + prefix_len_i,
        # b_start_loc[i] + seq_len_i), interleaved with the prefixes. They are
        # NOT packed at the tail of the buffer, so a single arange over the last
        # extend_token_num slots would point the unified kernel at the wrong K/V.
        extend_kv_indices = torch.cat(
            [
                torch.arange(
                    int(inp.b_start_loc[i] + inp.b_seq_len_prefix[i]),
                    int(inp.b_start_loc[i] + inp.b_seq_len[i]),
                    dtype=torch.int64,
                    device="cuda",
                )
                for i in range(B)
            ]
        )

        unified_kv_indptr, unified_kv_indices, prefix_lens = build_unified_kv_indices(
            inp.kv_indptr,
            inp.kv_indices,
            inp.b_start_loc_extend,
            inp.b_seq_len_extend,
            extend_kv_indices,
            B,
        )

        # Run unified kernel
        o_unified = torch.empty_like(inp.q_extend)
        extend_attention_fwd_unified(
            inp.q_extend,
            o_unified,
            inp.k_buffer,
            inp.v_buffer,
            inp.qo_indptr,
            unified_kv_indptr,
            unified_kv_indices,
            prefix_lens,
            max_len_extend=inp.max_len_extend,
            custom_mask=None,
            mask_indptr=None,
            sm_scale=None,
            logit_cap=0.0,
            is_causal=True,
        )

        # Both kernels now see identical K/V, so only accumulation-order / bf16
        # rounding differences remain.
        self.assertTrue(
            torch.allclose(o_regular, o_unified, rtol=1e-2, atol=1e-2),
            f"Unified kernel output differs from 2-stage kernel. "
            f"Max diff: {(o_regular - o_unified).abs().max()}",
        )

    def test_extend_attention_unified_vs_regular(self):
        """Test unified kernel matches 2-stage kernel across different configs."""
        configs = [
            (4, 512, 32, 8, 128),  # Standard config
            (2, 2048, 32, 8, 128),  # Long sequence (test 2048 specifically)
            (8, 256, 64, 8, 80),  # Non-standard head dim
        ]

        for B, N_CTX, H_Q, H_KV, D in configs:
            with self.subTest(B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D):
                self._test_extend_attention_unified_vs_regular_once(
                    B, N_CTX, H_Q, H_KV, D
                )

    def test_build_unified_kv_indices(self):
        """Test build_unified_kv_indices correctness."""
        B = 4
        dtype = torch.int64
        device = "cuda"

        prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device)
        extend_lens = torch.tensor([5, 3, 7, 4], dtype=torch.int32, device=device)

        # Prefix indices: request i owns prefix_kv_indices[indptr[i]:indptr[i+1]].
        prefix_kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
        prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
        prefix_kv_indices = torch.arange(
            prefix_lens.sum().item(), dtype=dtype, device=device
        )

        # Extend indices, packed in extend order starting after all prefixes.
        extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
        extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0)
        extend_kv_indices = torch.arange(
            prefix_lens.sum().item(),
            prefix_lens.sum().item() + extend_lens.sum().item(),
            dtype=dtype,
            device=device,
        )

        unified_kv_indptr, unified_kv_indices, returned_prefix_lens = (
            build_unified_kv_indices(
                prefix_kv_indptr,
                prefix_kv_indices,
                extend_start_loc,
                extend_lens,
                extend_kv_indices,
                B,
            )
        )

        # unified_kv_indptr is the CSR offset array of prefix_len + extend_len.
        expected_lens = prefix_lens + extend_lens
        expected_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
        expected_indptr[1:] = torch.cumsum(expected_lens, dim=0)
        self.assertTrue(torch.equal(unified_kv_indptr, expected_indptr))

        self.assertTrue(torch.equal(returned_prefix_lens, prefix_lens))

        # Each request's unified slice must be its prefix followed by its extend.
        for i in range(B):
            start_idx = int(unified_kv_indptr[i])
            end_idx = int(unified_kv_indptr[i + 1])
            prefix_len = int(prefix_lens[i])
            extend_len = int(extend_lens[i])

            unified_seq = unified_kv_indices[start_idx:end_idx].to(dtype)
            self.assertEqual(len(unified_seq), prefix_len + extend_len)

            prefix_start = int(prefix_kv_indptr[i])
            extend_start = int(extend_start_loc[i])
            self.assertTrue(
                torch.equal(
                    unified_seq[:prefix_len],
                    prefix_kv_indices[prefix_start : prefix_start + prefix_len],
                )
            )
            self.assertTrue(
                torch.equal(
                    unified_seq[prefix_len:],
                    extend_kv_indices[extend_start : extend_start + extend_len],
                )
            )

    # ------------------------------------------------------------------
    # Prefill / context attention
    # ------------------------------------------------------------------

    def _test_context_attention_once(self, head_dim, is_causal):
        """Compare ``context_attention_fwd`` with PyTorch SDPA per sequence."""
        num_heads = 4
        seq_lens = [8, 12]
        max_seq_len = max(seq_lens)
        total_tokens = sum(seq_lens)

        q = torch.randn(total_tokens, num_heads, head_dim, device="cuda")
        k = torch.randn(total_tokens, num_heads, head_dim, device="cuda")
        v = torch.randn(total_tokens, num_heads, head_dim, device="cuda")
        o = torch.zeros(total_tokens, num_heads, head_dim, device="cuda")

        # Start offset of each packed sequence (exclusive prefix sum).
        b_seq_len = torch.tensor(seq_lens, device="cuda")
        b_start_loc = torch.cumsum(b_seq_len, dim=0) - b_seq_len

        context_attention_fwd(
            q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
        )

        start = 0
        for seq_len in seq_lens:
            end = start + seq_len
            # SDPA expects [heads, seq, dim]; the packed layout is [seq, heads, dim].
            o_torch = torch.nn.functional.scaled_dot_product_attention(
                q[start:end].permute(1, 0, 2),
                k[start:end].permute(1, 0, 2),
                v[start:end].permute(1, 0, 2),
                is_causal=is_causal,
            ).permute(1, 0, 2)

            cos_sim = torch.nn.functional.cosine_similarity(
                o[start:end].flatten(), o_torch.flatten(), dim=0
            )
            self.assertTrue(cos_sim.item() > 1 - (1e-5))
            self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))
            start = end

    def test_context_attention(self):
        for head_dim in [128, 96, 80, 13]:
            for is_causal in [True, False]:
                with self.subTest(head_dim=head_dim, is_causal=is_causal):
                    self._test_context_attention_once(head_dim, is_causal)

    # ------------------------------------------------------------------
    # Decode attention
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_attention_ref(q, k_buffer, v_buffer, kv_indptr, kv_indices, sm_scale):
        """Float32 PyTorch reference for single-token decode attention.

        q: [B, H_Q, D]; k_buffer: [N, H_KV, D]; v_buffer: [N, H_KV, D_V].
        Request b attends to buffer slots kv_indices[kv_indptr[b]:kv_indptr[b+1]],
        and each KV head serves H_Q // H_KV consecutive query heads.
        Returns a float32 tensor of shape [B, H_Q, D_V].
        """
        batch, num_q_heads, _ = q.shape
        group = num_q_heads // k_buffer.shape[1]
        out = torch.empty(
            batch,
            num_q_heads,
            v_buffer.shape[-1],
            dtype=torch.float32,
            device=q.device,
        )
        for b in range(batch):
            slots = kv_indices[kv_indptr[b] : kv_indptr[b + 1]]
            keys = k_buffer[slots].float().repeat_interleave(group, dim=1)
            values = v_buffer[slots].float().repeat_interleave(group, dim=1)
            scores = torch.einsum("hd,shd->hs", q[b].float(), keys) * sm_scale
            out[b] = torch.einsum("hs,shd->hd", F.softmax(scores, dim=-1), values)
        return out

    def _test_decode_attention_once(self, B, H_Q, H_KV, D):
        """Run ``decode_attention_fwd`` and check it against a PyTorch reference."""
        dtype = torch.bfloat16
        seq_len = 10  # This represents the number of tokens already in the sequence
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        max_kv_splits = 8
        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

        b_seq_len = torch.full((B,), seq_len, device="cuda")
        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
        kv_indices = torch.arange(total_tokens, device="cuda")

        # Per-split partial outputs and their log-sum-exp, combined by stage 2.
        attn_logits = torch.empty(
            (B, H_Q, max_kv_splits, D), dtype=torch.float32, device="cuda"
        )
        attn_lse = torch.empty(
            (B, H_Q, max_kv_splits), dtype=torch.float32, device="cuda"
        )

        decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

        o_ref = self._decode_attention_ref(
            q, k_buffer, v_buffer, kv_indptr, kv_indices, sm_scale
        )
        cos_sim = F.cosine_similarity(o.float().flatten(), o_ref.flatten(), dim=0)
        self.assertTrue(cos_sim.item() > 0.99, f"cosine similarity {cos_sim.item()}")
        self.assertTrue(
            torch.allclose(o.float(), o_ref, atol=3e-2),
            f"max abs diff {(o.float() - o_ref).abs().max().item()}",
        )

    def test_decode_attention(self):
        configs = [
            (2, 4, 4, 64),  # MHA
            (2, 4, 2, 64),  # GQA
            (2, 4, 4, 80),  # Non-standard head dim
            (2, 4, 4, 13),  # Prime number head dim
        ]

        for B, H_Q, H_KV, D in configs:
            with self.subTest(B=B, H_Q=H_Q, H_KV=H_KV, D=D):
                self._test_decode_attention_once(B, H_Q, H_KV, D)

    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
        """Check that the grouped decode kernel matches the normal one.

        Every request has exactly ``S`` cached tokens. Values may use a
        different head dim (``D_V``) than queries and keys (``D``).
        """
        dtype = torch.bfloat16
        total_tokens = B * S
        sm_scale = 1.0 / (D**0.5)
        max_kv_splits = 8
        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")

        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

        b_seq_len = torch.full((B,), S, device="cuda")
        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
        kv_indices = torch.arange(total_tokens, device="cuda")

        def make_workspace():
            # Partial outputs [B, H_Q, splits, D_V] and their LSE [B, H_Q, splits].
            # Both kernels get identically shaped buffers.
            attn_logits = torch.empty(
                (B, H_Q, max_kv_splits, D_V), dtype=torch.float32, device="cuda"
            )
            attn_lse = torch.empty(
                (B, H_Q, max_kv_splits), dtype=torch.float32, device="cuda"
            )
            return attn_logits, attn_lse

        attn_logits, attn_lse = make_workspace()
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

        attn_logits1, attn_lse1 = make_workspace()
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o_grouped,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        self.assertTrue(cos_sim.item() > 0.99, f"cosine similarity {cos_sim.item()}")
        self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

    def test_grouped_decode_attention(self):
        seq_lens = [5, 100, 128, 500]
        configs = [
            (2, 16, 16, 64, 64),
            (2, 16, 1, 64, 64),
            (2, 64, 1, 13, 13),
            (2, 128, 1, 80, 80),
            (2, 128, 2, 512, 512),
            (2, 128, 1, 576, 512),
        ]

        for S in seq_lens:
            for B, H_Q, H_KV, D, D_V in configs:
                with self.subTest(S=S, B=B, H_Q=H_Q, H_KV=H_KV, D=D, D_V=D_V):
                    self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
774
775
776

if __name__ == "__main__":
    unittest.main()