"test/verify/test_add_gelu_half.cpp" did not exist on "c5d8c71c89d57e86f2e4484a69a65496ca9e37ee"
test_triton_attention_kernels.py 13.2 KB
Newer Older
1
2
3
4
5
import random
import unittest

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_fwd_grouped,
    decode_attention_fwd_normal,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
    extend_attention_fwd,
    redundant_attention,
)
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)
from sglang.test.test_utils import CustomTestCase


class TestTritonAttention(CustomTestCase):

    def _set_all_seeds(self, seed):
        """Set all random seeds for reproducibility."""
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def setUp(self):
        # Set seeds before each test method
        self._set_all_seeds(42)

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len_extend = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

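        # Index the prefix KV tokens: kv_indptr[i] marks where batch i's prefix
        # entries start in kv_indices, which point into the shared KV buffer.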
        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
        kv_indices = torch.zeros(
            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
        )

        for i in range(B):
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
            )

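        # Allocate the full KV buffers (prefix + extend tokens); the extend slice of
        # k/v is copied into k_extend/v_extend, while q_extend is freshly sampled.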
        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()
        k_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)
        v_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.empty(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
            ).normal_(mean=0.1, std=0.2)

        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        o_extend_mask = torch.empty(
            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
        )
        o_redundant = torch.empty(
            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
        )

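        # Recompute the extend lengths and build qo_indptr over the extend (query) tokens.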
        b_seq_len_extend = b_seq_len - b_seq_len_prefix
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

        custom_mask = None
        mask_indptr = None

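        # First pass: run the extend kernel without a custom mask.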
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_len_extend,
        )

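        # Build an explicit boolean mask: full attention to the prefix plus a causal
        # mask over the extend tokens.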
        b_seq_mask_len = b_seq_len_extend * b_seq_len
        custom_mask = torch.ones(
            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
        )
        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
        mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
        for i in range(B):
            causal_mask = (
                torch.tril(
                    torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0
                )
                == 1
            )
            prefix_mask = torch.ones(
                b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
            )
            mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
            custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten

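        # Second pass: same inputs, but with the explicit custom mask.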
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend_mask,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_len_extend,
        )

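        # Reference result from the straightforward (redundant) implementation.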
        redundant_attention(
            q_extend,
            o_redundant,
            k_buffer,
            v_buffer,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            b_seq_len_prefix,
            max_len_in_batch,
        )

        self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
        self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))

    def test_extend_attention(self):

        # Define the varying parameter values
        attention_values = [128, 96, 80, 13]

        # Loop through the values and call the method
        for value in attention_values:
            self._test_extend_attention_once(19, 12331, 12, 4, value)

    def _test_context_attention_once(self, head_dim, is_causal):
        # Set up a simple test case
        num_heads = 4
        seq_lens = [8, 12]
        max_seq_len = max(seq_lens)

        # Create random input tensors
        q = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
        k = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
        v = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
        o = torch.zeros(sum(seq_lens), num_heads, head_dim, device="cuda")

        # Create b_start_loc and b_seq_len tensors
        b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
        b_seq_len = torch.tensor(seq_lens, device="cuda")

        context_attention_fwd(
            q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
        )

        cu_seq_lens = [0] * (len(seq_lens) + 1)
        for i, seq_len in enumerate(seq_lens):
            cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len

        for i in range(len(seq_lens)):
            start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
            o_torch = torch.nn.functional.scaled_dot_product_attention(
                q[start:end].permute(1, 0, 2),
                k[start:end].permute(1, 0, 2),
                v[start:end].permute(1, 0, 2),
                is_causal=is_causal,
            ).permute(1, 0, 2)

            cos_sim = torch.nn.functional.cosine_similarity(
                o[start:end].flatten(), o_torch.flatten(), dim=0
            )
            self.assertTrue(cos_sim.item() > 1 - (1e-5))
            self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))

    def test_context_attention(self):
        head_dim = [128, 96, 80, 13]

        for dim in head_dim:
            for is_causal in [True, False]:
                self._test_context_attention_once(dim, is_causal)

    def _test_decode_attention_once(self, B, H_Q, H_KV, D):
        dtype = torch.bfloat16
        seq_len = 10  # This represents the number of tokens already in the sequence
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        max_kv_splits = 8
        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

        b_seq_len = torch.full((B,), seq_len, device="cuda")

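        # Every batch owns a contiguous block of seq_len tokens in the KV buffers.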
        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
        kv_indices = torch.arange(total_tokens, device="cuda")

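        # Split-KV workspaces: attn_logits holds per-split partial outputs and
        # attn_lse the per-split log-sum-exp values.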
        attn_logits = torch.empty(
            (B, H_Q, max_kv_splits, D),
            dtype=torch.float32,
            device="cuda",
        )
        attn_lse = torch.empty(
            (B, H_Q, max_kv_splits),
            dtype=torch.float32,
            device="cuda",
        )

        decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

    def test_decode_attention(self):
        # Here we just ensure there is no error
        # TODO: correctness test

        # Test configurations
        configs = [
            (2, 4, 4, 64),  # MHA
            (2, 4, 2, 64),  # GQA
            (2, 4, 4, 80),  # Non-standard head dim
            (2, 4, 4, 13),  # Prime number head dim
        ]

        for B, H_Q, H_KV, D in configs:
            self._test_decode_attention_once(B, H_Q, H_KV, D)

    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
        dtype = torch.bfloat16
        seq_len = S  # This represents the number of tokens already in the sequence
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        max_kv_splits = 8
        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

        b_seq_len = torch.full((B,), seq_len, device="cuda")

        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
        kv_indices = torch.arange(total_tokens, device="cuda")

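        # Workspaces for the normal (reference) decode kernel.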
        attn_logits = torch.empty(
            (B, H_Q, max_kv_splits, D_V),
            dtype=torch.float32,
            device="cuda",
        )
        attn_lse = torch.empty(
            (B, H_Q, max_kv_splits),
            dtype=torch.float32,
            device="cuda",
        )

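        # Reference: the non-grouped decode kernel.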
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits,
            attn_lse,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

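        # Separate workspaces for the grouped kernel so the two runs share no state.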
        attn_logits1 = torch.empty(
            (B, H_Q, max_kv_splits, D_V),
            dtype=torch.float32,
            device="cuda",
        )
        attn_lse1 = torch.empty(
            (B, H_Q, max_kv_splits, D_V),
            dtype=torch.float32,
            device="cuda",
        )

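        # Grouped decode kernel under test; its output should match the reference.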
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o_grouped,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        print(cos_sim.item())
        self.assertTrue(cos_sim.item() > 0.99)
        self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

    def test_grouped_decode_attention(self):
        seq_lens = [5, 100, 128, 500]
        configs = [
            (2, 16, 16, 64, 64),
            (2, 16, 1, 64, 64),
            (2, 64, 1, 13, 13),
            (2, 128, 1, 80, 80),
            (2, 128, 2, 512, 512),
            (2, 128, 1, 576, 512),
        ]

        for S in seq_lens:
            for B, H_Q, H_KV, D, D_V in configs:
                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)


if __name__ == "__main__":
    unittest.main()