sparse_utils.py 24.3 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
"""
Copyright (c) 2025 by SpargeAttn team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
import triton
import triton.language as tl
from torch import Tensor


def precision_metric(quant_o, fa2_o, verbose=True, round_num=4):
    """
    Compare an approximate output against a reference output.

    Both tensors are upcast to float32 and compared element-wise.

    Args:
        quant_o: tensor under test.
        fa2_o: reference tensor (same number of elements as ``quant_o``).
        verbose: if true, print the three metrics.
        round_num: number of decimal places each metric is rounded to.

    Returns:
        dict with keys "Cossim" (cosine similarity of the flattened tensors),
        "L1" (sum |quant_o - fa2_o| / sum |fa2_o|) and "RMSE".
    """
    # Inputs whose second-to-last dim exceeds 200000 are compared on the CPU.
    use_cpu = quant_o.shape[-2] > 200000
    approx = (quant_o.cpu() if use_cpu else quant_o).float()
    ref = (fa2_o.cpu() if use_cpu else fa2_o).float()

    cos = F.cosine_similarity(approx.reshape(1, -1), ref.reshape(1, -1)).item()
    delta = approx - ref
    rel_l1 = (delta.abs().sum() / ref.abs().sum()).item()
    root_mse = delta.pow(2).mean().sqrt().item()

    metrics = {
        "Cossim": round(cos, round_num),
        "L1": round(rel_l1, round_num),
        "RMSE": round(root_mse, round_num),
    }
    if verbose:
        print(f'Cossim: {metrics["Cossim"]:.6f}, L1: {metrics["L1"]:.6f}, RMSE:{metrics["RMSE"]:.6f}')
    return metrics

def hyperparameter_check(hyper, H, device):
    """
    Broadcast a per-head hyperparameter to a 1-D tensor with ``H`` entries.

    Args:
        hyper: a real scalar (Python int/float or NumPy real scalar; bool is
            rejected), a 0-d tensor, or a 1-D tensor with exactly ``H``
            elements.
        H: number of heads.
        device: device of the returned tensor.

    Returns:
        Tensor of shape (H,) on ``device``. Scalars become float tensors; a
        1-D tensor keeps its dtype.

    Raises:
        AssertionError: tensor has more than one dim or not ``H`` elements.
        ValueError: ``hyper`` is of an unsupported type.
    """
    import numbers  # local import: only needed for the scalar type test

    # numbers.Real also covers NumPy floating scalars (e.g. np.float32),
    # which the former `type(x) == float` check rejected. bool is a Real
    # subclass of int but is not a meaningful threshold, so it stays rejected.
    if isinstance(hyper, numbers.Real) and not isinstance(hyper, bool):
        return torch.full((H,), float(hyper), device=device)

    if isinstance(hyper, Tensor):
        assert hyper.dim() <= 1, "Hyperparameter tensor must be 1D"
        if hyper.dim() == 0:
            hyper = torch.full((H,), hyper.item(), device=device)
        assert hyper.numel() == H, f"Hyperparameter tensor must have {H} elements, but has {hyper.numel()}"
        return hyper.to(device)

    # Report the offending value in the exception instead of printing it.
    raise ValueError(
        f"Hyperparameter must be a float, int or a tensor, got {type(hyper).__name__}: {hyper!r}"
    )



@triton.jit
def triton_block_map_to_lut_kernel(map_ptr, lut_ptr, valid_block_num_ptr, num_block_k):
    """
    Delta-encode the selected blocks of one block-map row.

    Launched with grid (B, H, Q) by ``block_map_lut_triton``. Program (b, h, q)
    scans the num_block_k entries of row (b, h, q) of a row-major
    (B, H, Q, num_block_k) map. For each nonzero entry i it appends
    i - prev_block to the same row of the LUT. prev_block is the previously
    selected index and starts at 0, so the first value is an absolute index
    and later values are gaps. Finally it stores the number of selected
    entries at valid_block_num_ptr[b, h, q]. LUT slots beyond that count are
    not written by this kernel.
    """
    b, h, q = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Grid extents; H and Q are used for the row-major offsets (B is unused).
    B, H, Q = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)
    valid_block_num = 0

    # Advance all three pointers to row (b, h, q).
    map_ptr = map_ptr + b * H * Q * num_block_k + h * Q * num_block_k + q * num_block_k
    lut_ptr = lut_ptr + b * H * Q * num_block_k + h * Q * num_block_k + q * num_block_k
    valid_block_num_ptr = valid_block_num_ptr + b * H * Q + h * Q + q

    # NOTE: re-initialises valid_block_num (already 0 above); harmless.
    valid_block_num = 0
    prev_block = 0

    # Sequential scan; selected blocks are compacted to the front of the LUT row.
    for i in range(num_block_k):
        cur_block = tl.load(map_ptr + i)
        if cur_block:
            tl.store(lut_ptr + valid_block_num, i - prev_block)
            valid_block_num += 1
            prev_block = i

    tl.store(valid_block_num_ptr, valid_block_num)

@triton.jit
def triton_block_map_to_offset_kernel(map_ptr, offset_ptr, block_count_ptr, num_block_k):
    """
    Convert block_map to block_offset format using Triton.

    One program per block-map row (grid (B, H, Q)). The program writes the
    ascending indices of the nonzero entries of row (b, h, q) to the front of
    the matching block_offset row, and writes their count to block_count.
    Slots past the count are not written here; block_map_to_block_offset_triton
    pre-fills them with 10000000.

    Args:
        map_ptr: Pointer to block_map (B, H, Q, K) boolean tensor, row-major
        offset_ptr: Pointer to output block_offset (B, H, Q, K) int32 tensor
        block_count_ptr: Pointer to output block_count (B, H, Q) int32 tensor
        num_block_k: Number of blocks in K dimension
    """
    b, h, q = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Grid extents, used only for the row-major offsets (B is unused).
    B, H, Q = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)

    # Advance pointers to row (b, h, q).
    map_ptr = map_ptr + b * H * Q * num_block_k + h * Q * num_block_k + q * num_block_k
    offset_ptr = offset_ptr + b * H * Q * num_block_k + h * Q * num_block_k + q * num_block_k
    block_count_ptr = block_count_ptr + b * H * Q + h * Q + q

    block_count = 0

    # Collect absolute indices of selected blocks
    for i in range(num_block_k):
        cur_block = tl.load(map_ptr + i)
        if cur_block:
            tl.store(offset_ptr + block_count, i)
            block_count += 1

    tl.store(block_count_ptr, block_count)

def block_map_lut_triton(block_map):
    """
    Triton counterpart of ``block_map_lut``: delta-encode each block-map row.

    Args:
        block_map: contiguous 4-D (B, H, Q, K) block map. Nonzero entries mark
            selected key blocks.

    Returns:
        lut: (B, H, Q, K) int32. Each row starts with the first selected index
            followed by gaps between consecutive selected indices; the
            remaining slots are 0.
        valid_block_num: (B, H, Q) int32 number of selected blocks per row.
    """
    assert block_map.dim() == 4
    assert block_map.is_contiguous()

    batch, heads, q_blocks, k_blocks = block_map.shape
    dev = block_map.device
    lut = torch.zeros((batch, heads, q_blocks, k_blocks), dtype=torch.int32, device=dev)
    valid_block_num = torch.zeros((batch, heads, q_blocks), dtype=torch.int32, device=dev)

    # One program per (batch, head, query-block) row.
    launch_grid = (batch, heads, q_blocks)
    triton_block_map_to_lut_kernel[launch_grid](block_map, lut, valid_block_num, k_blocks)
    return lut, valid_block_num

def block_map_to_block_offset_triton(block_map):
    """
    Triton counterpart of ``block_map_to_block_offset``.

    Args:
        block_map: contiguous 4-D (B, H, Q, K) block map. Nonzero entries mark
            selected key blocks.

    Returns:
        block_offset: (B, H, Q, K) int32. Each row starts with the ascending
            indices of its selected blocks; the remaining slots hold 10000000,
            the same sentinel as the PyTorch version.
        block_count: (B, H, Q) int32 number of selected blocks per row.
    """
    assert block_map.dim() == 4
    assert block_map.is_contiguous()

    batch, heads, q_blocks, k_blocks = block_map.shape
    dev = block_map.device
    # Pre-fill with the sentinel: the kernel only writes the first block_count slots.
    block_offset = torch.full(
        (batch, heads, q_blocks, k_blocks), 10000000, dtype=torch.int32, device=dev
    )
    block_count = torch.zeros((batch, heads, q_blocks), dtype=torch.int32, device=dev)

    launch_grid = (batch, heads, q_blocks)
    triton_block_map_to_offset_kernel[launch_grid](block_map, block_offset, block_count, k_blocks)
    return block_offset, block_count

@triton.jit
def qk_quantize(
    # Pointers
    x_ptr,
    xm_ptr,
    x_quant_ptr,
    scale_ptr,
    # Constexpr dimensions
    N: tl.constexpr,
    D: tl.constexpr,
    BS: tl.constexpr,
    fuse_mean: tl.constexpr
):
    """
    Triton kernel to perform per-block symmetric quantization of a tensor X to int8.

    Each program loads one block of BS rows from X, indexed as a contiguous
    (B, H, N, D) tensor. It optionally subtracts a mean vector, then computes
    one scale for the block, max(|x|) / 127 + 1e-7. It stores x / scale,
    rounded half away from zero, as int8 together with the scale.

    Args:
        x_ptr: input, indexed as contiguous (B, H, N, D).
        xm_ptr: mean vectors, indexed as one D-vector per (b, h). Only read
            when fuse_mean is true; get_quant passes None otherwise.
        x_quant_ptr: int8 output with the same layout as X.
        scale_ptr: one scale per (b, h, block), indexed as (B, H, NB).
        N: number of rows (sequence length).
        D: row width.
        BS: rows per block.
        fuse_mean: whether to subtract the mean vector before quantizing.

    Grid: (B, H, NB)
        B: Batch size
        H: Number of heads
        NB: Number of blocks in the N dimension, ceil(N / BS) (the last block
            may be partial)
    """
    # 1. Get program IDs to identify the current block
    b, h, nb = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    B, H, NB = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)

    # 2. Calculate pointers for the input block X
    block_offset = b * H * N * D + h * N * D + nb * BS * D
    x_ptrs = x_ptr + block_offset + tl.arange(0, BS)[:, None] * D + tl.arange(0, D)[None, :]
    
    # Create a mask to handle the last block if N is not a multiple of BS
    xmask = (nb * BS + tl.arange(0, BS)[:, None]) < N
    
    # Load the input block; rows past N read as 0 so they cannot affect the scale.
    x = tl.load(x_ptrs, mask=xmask, other=0.0)

    # 3. (Optional) Subtract the mean if fuse_mean is enabled
    if fuse_mean:
        xm_ptrs = xm_ptr + b * H * D + h * D + tl.arange(0, D)
        x_mean = tl.load(xm_ptrs)
        x -= x_mean
        # Re-apply mask to zero out padded values after subtraction
        x = tl.where(xmask, x, 0.0)

    # 4. Perform quantization
    # Convert to float32 for stable calculations
    x_fp32 = x.to(tl.float32)

    # Calculate the scale: max(abs(x)) / 127.0
    # The scale is per-block
    scale = tl.max(tl.abs(x_fp32)) / 127.0
    # Add a small epsilon to avoid division by zero (added unconditionally,
    # so the stored scale includes it)
    scale += 1e-7

    # Quantize to int8: (x / scale) and round to nearest integer
    x_int8 = x_fp32 / scale
    # Round half away from zero: add +/-0.5, then the int8 cast truncates toward zero
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)

    # 5. Calculate output pointers and store the results
    # Pointers for the quantized output tensor
    x_quant_ptrs = x_quant_ptr + block_offset + tl.arange(0, BS)[:, None] * D + tl.arange(0, D)[None, :]
    # Pointer for the scale value of this block
    scale_ptrs = scale_ptr + b * H * NB + h * NB + nb

    # Store the quantized int8 values (padded rows are not written)
    tl.store(x_quant_ptrs, x_int8, mask=xmask)
    # Store the scale value
    tl.store(scale_ptrs, scale)

@triton.jit
def triton_bmm_pool_sim_simmean_fuse_quant(
    x_ptr,
    xm_ptr,
    pool_ptr,
    sim_ptr,
    x_quant_ptr,
    scale_ptr,
    simthreshd1,
    N: tl.constexpr,
    D: tl.constexpr,
    BS: tl.constexpr,
    fuse_mean: tl.constexpr
):
    """
    Fused per-block mean pooling, self-similarity test and int8 quantization.

    Grid (B, H, NB), with NB = ceil(N / BS). Program (b, h, nb) handles rows
    [nb*BS, min((nb+1)*BS, N)) of X, indexed as a contiguous (B, H, N, D)
    tensor. If fuse_mean is set, it first subtracts the D-vector at xm_ptr for
    (b, h). It then:
      * stores the mean of the valid rows to pool_ptr, indexed as (B, H, NB, D);
      * stores to sim_ptr, indexed as (B, H, NB), whether the mean of the
        BS_ x BS_ cosine-similarity matrix of the valid rows (diagonal
        included) exceeds simthreshd1[h];
      * stores the block quantized to int8 to x_quant_ptr, and its scale
        max(|x|) / 127 + 1e-7 to scale_ptr.

    Rows past N in a trailing partial block are loaded as zero and kept
    finite. They therefore contribute nothing to the pool, the similarity or
    the scale.
    """
    b, h, nb = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    B, H, NB = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)

    block_offset = b * H * N * D + h * N * D + nb * BS * D
    xmask = (nb*BS + tl.arange(0, BS)[:, None]) < N
    x_ptrs = x_ptr + block_offset + tl.arange(0, BS)[:, None] * D + tl.arange(0, D)[None, :]
    # A masked load without `other` leaves padded lanes undefined. Read them as
    # zero so they drop out of the pooled sum, the gram matrix and the scale.
    x = tl.load(x_ptrs, mask=xmask, other=0.0)
    # Number of valid rows in this block (< BS only for a trailing partial block).
    BS_ = BS if (N - nb*BS) >= BS else (N - nb*BS)

    if fuse_mean:
        xm_ptrs = xm_ptr + b * H * D + h * D + tl.arange(0, D)
        x_mean = tl.load(xm_ptrs)
        x -= x_mean
        # Padded rows became -mean; zero them again.
        x = tl.where(xmask, x, 0)

    cur_h1 = tl.load(simthreshd1 + h)
    x_fp32 = x.to(tl.float32)

    pool = (tl.sum(x_fp32, axis=0) / BS_)
    x_norm = tl.sqrt(tl.sum(x_fp32 * x_fp32, axis=1, keep_dims=True))
    # All-zero rows (including padding) have norm 0. Dividing them by 1 keeps
    # them at 0. Dividing by 0 would give 0/0 = NaN, poison the gram sum and
    # force cur_sim to False for every partial block.
    x_norm = tl.where(x_norm == 0, 1.0, x_norm)
    x = (x / x_norm).to(tl.float16)  # norm at D dim

    grams = tl.dot(x, tl.trans(x))
    sum_value = tl.sum(grams).to(tl.float32)
    # Mean pairwise cosine similarity over the BS_ valid rows.
    cur_sim = (sum_value / (BS_ * BS_)) > cur_h1

    pool_block_offset = b * H * NB * D + h * NB * D + nb * D
    tl.store(pool_ptr + pool_block_offset + tl.arange(0, D), pool)
    sim_offset = b * H * NB + h * NB + nb
    tl.store(sim_ptr + sim_offset, cur_sim)

    # Per-block symmetric int8 quantization of the (mean-subtracted) input.
    scale = tl.max(tl.abs(x_fp32)) / 127.
    scale += 0.0000001
    x_int8 = x_fp32 / scale
    # Round half away from zero; the int8 cast truncates toward zero.
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)
    x_quant_ptrs = x_quant_ptr + block_offset + tl.arange(0, BS)[:, None] * D + tl.arange(0, D)[None, :]
    scale_ptrs = scale_ptr + b * H * NB + h * NB + nb
    tl.store(x_quant_ptrs, x_int8, mask = xmask)
    tl.store(scale_ptrs, scale)

@triton.jit
def triton_bmm_pool_sim_simmean(x_ptr, pool_ptr, sim_ptr, simthreshd1, N: tl.constexpr, D: tl.constexpr, BS: tl.constexpr):
    """
    Per-block mean pooling and self-similarity test (no quantization).

    Grid (B, H, NB), with NB = ceil(N / BS). Program (b, h, nb) handles rows
    [nb*BS, min((nb+1)*BS, N)) of X, indexed as a contiguous (B, H, N, D)
    tensor. It stores:
      * the mean of the valid rows to pool_ptr, indexed as (B, H, NB, D);
      * to sim_ptr, indexed as (B, H, NB), whether the mean of the
        BS_ x BS_ cosine-similarity matrix of the valid rows (diagonal
        included) exceeds simthreshd1[h].

    Rows past N in a trailing partial block are loaded as zero and kept
    finite, so they contribute nothing to either result.
    """
    b, h, nb = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    B, H, NB = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)

    block_offset = b * H * N * D + h * N * D + nb * BS * D
    xmask = (nb*BS + tl.arange(0, BS)[:, None]) < N
    x_ptrs = x_ptr + block_offset + tl.arange(0, BS)[:, None] * D + tl.arange(0, D)[None, :]
    # A masked load without `other` leaves padded lanes undefined; read them as zero.
    x = tl.load(x_ptrs, mask=xmask, other=0.0)
    # Number of valid rows in this block (< BS only for a trailing partial block).
    BS_ = BS if (N - nb*BS) >= BS else (N - nb*BS)

    cur_h1 = tl.load(simthreshd1 + h)
    x_fp32 = x.to(tl.float32)
    pool = (tl.sum(x_fp32, axis=0) / BS_)
    x_norm = tl.sqrt(tl.sum(x_fp32 * x_fp32, axis=1, keep_dims=True))
    # All-zero rows (including padding) have norm 0. Dividing them by 1 keeps
    # them at 0 instead of producing 0/0 = NaN, which would poison the gram sum.
    x_norm = tl.where(x_norm == 0, 1.0, x_norm)
    x = (x / x_norm).to(tl.float16)  # norm at D dim

    grams = tl.dot(x, tl.trans(x))
    sum_value = tl.sum(grams).to(tl.float32)
    # Mean pairwise cosine similarity over the BS_ valid rows.
    cur_sim = (sum_value / (BS_ * BS_)) > cur_h1

    pool_block_offset = b * H * NB * D + h * NB * D + nb * D
    tl.store(pool_ptr + pool_block_offset + tl.arange(0, D), pool)
    sim_offset = b * H * NB + h * NB + nb
    tl.store(sim_ptr + sim_offset, cur_sim)
    
    
def get_pool_sim_triton_simmean(x, block_size, simthreshd1):
    """
    Mean-pool ``x`` in blocks of ``block_size`` rows and flag self-similar blocks.

    Args:
        x: tensor unpacked as (B, H, N, D); made contiguous before launch.
        block_size: rows per block along N.
        simthreshd1: per-head threshold tensor, read by the kernel at index h.

    Returns:
        pool: (B, H, nblock, D) in x.dtype, the mean of each block.
        sim_blocks: (B, H, nblock) bool, the similarity flag of each block.
        Here nblock = ceil(N / block_size).
    """
    x = x.contiguous()
    batch, heads, seq_len, head_dim = x.shape
    num_blocks = (seq_len + block_size - 1) // block_size
    dev = x.device

    pool = torch.empty((batch, heads, num_blocks, head_dim), device=dev, dtype=x.dtype)
    sim_blocks = torch.empty((batch, heads, num_blocks), device=dev, dtype=torch.bool)

    # One program per (batch, head, block).
    triton_bmm_pool_sim_simmean[(batch, heads, num_blocks)](
        x, pool, sim_blocks, simthreshd1, N=seq_len, D=head_dim, BS=block_size
    )
    return pool, sim_blocks
 
# TODO(xingyang): wrapper for tensor quantization
def get_quant(x, x_mean, block_size):
    """
    Quantize ``x`` to int8 with one float32 scale per block of ``block_size`` rows.

    Args:
        x: tensor unpacked as (B, H, N, D); made contiguous before launch.
        x_mean: optional tensor subtracted from every row before quantizing,
            or None. NOTE(review): qk_quantize indexes it as one D-vector per
            (b, h); confirm the caller's layout matches.
        block_size: rows per quantization block along N.

    Returns:
        x_quant: int8 tensor shaped like ``x``.
        x_scale: (B, H, ceil(N / block_size)) float32 per-block scales.
    """
    x = x.contiguous()
    batch, heads, seq_len, head_dim = x.shape
    num_blocks = (seq_len + block_size - 1) // block_size
    dev = x.device

    x_quant = torch.empty(x.shape, device=dev, dtype=torch.int8)
    x_scale = torch.empty((batch, heads, num_blocks), device=dev, dtype=torch.float32)

    subtract_mean = x_mean is not None
    qk_quantize[(batch, heads, num_blocks)](
        x, x_mean, x_quant, x_scale,
        N=seq_len, D=head_dim, BS=block_size, fuse_mean=subtract_mean,
    )
    return x_quant, x_scale

def get_vanilla_qk_quant(q, k, km=None, BLKQ=128, BLKK=64):
    """
    Block-wise int8 quantization of q and k.

    q is quantized in blocks of BLKQ rows with no mean subtraction. k is
    quantized in blocks of BLKK rows, with ``km`` subtracted first when given.

    Returns:
        (q_int8, q_scale, k_int8, k_scale) as produced by ``get_quant``.
    """
    quantized_q = get_quant(q, None, BLKQ)
    quantized_k = get_quant(k, km, BLKK)
    return (*quantized_q, *quantized_k)

def get_pool_sim_triton_simmean_fuse_quant(x, x_mean, block_size, simthreshd1):
    """
    Run the fused pool / similarity / int8-quantization kernel over ``x``.

    Args:
        x: tensor unpacked as (B, H, N, D); made contiguous before launch.
        x_mean: optional tensor subtracted from every row first, or None.
            NOTE(review): the kernel indexes it as one D-vector per (b, h);
            confirm the caller's layout matches.
        block_size: rows per block along N.
        simthreshd1: per-head threshold tensor, read by the kernel at index h.

    Returns:
        pool: (B, H, nblock, D) in x.dtype.
        sim_blocks: (B, H, nblock) bool.
        x_quant: int8 tensor shaped like ``x``.
        x_scale: (B, H, nblock) float32.
        Here nblock = ceil(N / block_size).
    """
    x = x.contiguous()
    batch, heads, seq_len, head_dim = x.shape
    num_blocks = (seq_len + block_size - 1) // block_size
    dev = x.device

    pool = torch.empty((batch, heads, num_blocks, head_dim), device=dev, dtype=x.dtype)
    sim_blocks = torch.empty((batch, heads, num_blocks), device=dev, dtype=torch.bool)
    x_quant = torch.empty(x.shape, device=dev, dtype=torch.int8)
    x_scale = torch.empty((batch, heads, num_blocks), device=dev, dtype=torch.float32)

    subtract_mean = x_mean is not None
    triton_bmm_pool_sim_simmean_fuse_quant[(batch, heads, num_blocks)](
        x, x_mean, pool, sim_blocks, x_quant, x_scale, simthreshd1,
        N=seq_len, D=head_dim, BS=block_size, fuse_mean=subtract_mean,
    )
    return pool, sim_blocks, x_quant, x_scale

@triton.jit
def triton_fill_block_map_kernel(final_map, num_to_select, sorted_indices, NK: tl.constexpr):
    """
    Mark the top-ranked key blocks of one block-map row.

    One program per (b, h, q); fill_block_map_triton launches it with grid
    (B, H, Q). The program reads n = num_to_select[b, h, q] and writes 1 into
    final_map[b, h, q, idx] for the first n indices idx of
    sorted_indices[b, h, q, :]. n is raised to 1 when it is 0. All tensors are
    indexed as row-major with NK key blocks per row. Entries that are not
    listed keep their existing value.
    """
    b, h, q = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    B, H, Q = tl.num_programs(0), tl.num_programs(1), tl.num_programs(2)
    cur_num_to_select = tl.load(num_to_select + b * H * Q + h * Q + q)
    cur_sorted_idx_ptr = sorted_indices + b * H * Q * NK + h * Q * NK + q * NK
    cur_final_map_ptr = final_map + b * H * Q * NK + h * Q * NK + q * NK
    # Always select at least one block, even when num_to_select is 0.
    cur_num_to_select = (cur_num_to_select + 1) if cur_num_to_select == 0 else cur_num_to_select
    for i in range(cur_num_to_select):
        cur_idx = tl.load(cur_sorted_idx_ptr + i)
        tl.store(cur_final_map_ptr + cur_idx, 1)
    

def fill_block_map_triton(final_map, num_to_select, sorted_indices):
    """
    Mark the first ``num_to_select`` entries of ``sorted_indices`` in each row.

    Args:
        final_map: (B, H, Q, K) block map that receives 1s.
        num_to_select: (B, H, Q) count per row; a count of 0 is treated as 1
            by the kernel.
        sorted_indices: (B, H, Q, K) key-block indices in ranking order.

    Returns:
        The filled map. It is a contiguous copy if ``final_map`` was not
        contiguous, so callers must use the return value.
    """
    final_map, num_to_select, sorted_indices = (
        t.contiguous() for t in (final_map, num_to_select, sorted_indices)
    )
    batch, heads, q_blocks, k_blocks = final_map.shape
    triton_fill_block_map_kernel[(batch, heads, q_blocks)](
        final_map, num_to_select, sorted_indices, k_blocks
    )
    return final_map

@triton.jit
def triton_fill_causal_mask(mask, BqdivBk):
    """
    Fill a (Q, K) block-level causal mask, one program per entry (grid (Q, K)).

    Sets mask[q, k] = 0 when k >= (q + 1) * BqdivBk, and 1 otherwise.
    BqdivBk is the query-to-key block-size ratio (BLKQ / BLKK in the callers).
    A key block k, starting at k * BLKK, is therefore kept iff it starts
    before the end of query block q at (q + 1) * BLKQ.
    NOTE(review): indexes mask as contiguous row-major with K columns.
    """
    q, k = tl.program_id(0), tl.program_id(1)
    Q, K = tl.num_programs(0), tl.num_programs(1)
    if k >= (q + 1) * BqdivBk:
        tl.store(mask + q * K + k, 0)
    else:
        tl.store(mask + q * K + k, 1)

def fill_causal_mask_triton(mask, BqdivBk: float):
    """
    Fill the 2-D block mask ``mask`` in place with a block-causal pattern.

    mask[q, k] becomes 1 when k < (q + 1) * BqdivBk and 0 otherwise. The mask
    is indexed as contiguous row-major by the kernel.

    Returns:
        ``mask`` itself, for chaining.
    """
    assert mask.dim() == 2
    # One kernel program per mask entry.
    launch_grid = tuple(mask.shape)
    triton_fill_causal_mask[launch_grid](mask, BqdivBk)
    return mask


def _check_block_map_args(q, simthreshd1, cdfthreshd, topk, return_lut, return_block_offset):
    """Validate block-map selection options and broadcast them to per-head tensors.

    Returns ``(simthreshd1, cdfthreshd, topk)``, each a tensor with
    ``q.size(1)`` entries on ``q.device``. ``cdfthreshd`` / ``topk`` stay
    ``None`` when not given.
    """
    assert (cdfthreshd is None and topk is not None) \
        or (cdfthreshd is not None and topk is None), "Only one of cdfthreshd and topk can be set."
    assert not (return_lut and return_block_offset), "Only one of return_lut and return_block_offset can be True."

    Headnum = q.size(1)
    simthreshd1 = hyperparameter_check(simthreshd1, Headnum, q.device)
    if cdfthreshd is not None:
        cdfthreshd = hyperparameter_check(cdfthreshd, Headnum, q.device)
    if topk is not None:
        topk = hyperparameter_check(topk, Headnum, q.device)
    return simthreshd1, cdfthreshd, topk


def _select_block_map(pooled_qblocks, sim_qblocks, pooled_kblocks, sim_kblocks, head_dim,
                      is_causal, BLKQ, BLKK, cdfthreshd, topk, attention_sink):
    """Build the boolean (B, H, nq, nk) block map from pooled blocks and similarity flags."""
    nq = pooled_qblocks.shape[-2]
    nk = pooled_kblocks.shape[-2]

    sim_kblocks = sim_kblocks.unsqueeze(-2).expand(-1, -1, nq, -1)  # faster than repeat
    sim_qblocks = sim_qblocks.unsqueeze(-1).expand(-1, -1, -1, nk)
    pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) * head_dim ** -0.5
    # Non-similar key blocks are always kept below, so exclude them from the ranking.
    pooled_score[~sim_kblocks] = -torch.inf
    if is_causal:
        empty_mask = torch.empty(nq, nk, device=pooled_score.device, dtype=torch.bool)
        causal_mask = fill_causal_mask_triton(empty_mask, BLKQ / BLKK)
        pooled_score = pooled_score.masked_fill(~causal_mask[None, None, ...], -torch.inf)
    pooled_score = pooled_score.softmax(-1)
    sorted_score = torch.sort(pooled_score, dim=-1, descending=True)
    cdf = torch.cumsum(sorted_score.values, dim=-1)
    B, H, Q, K = cdf.shape
    if cdfthreshd is not None:
        # Per row: number of top-ranked blocks whose cumulative probability
        # is <= the head's threshold.
        cdfthreshd_ts = cdfthreshd.view(1, H, 1, 1)
        cdfthreshd_ts = cdfthreshd_ts.expand(B, -1, Q, 1).contiguous()
        num_to_select = torch.searchsorted(cdf, cdfthreshd_ts, right=True).squeeze(-1)
    else:
        # topk is a per-head fraction of the K key blocks, truncated to an int.
        num_to_select = (topk * K).to(torch.int64).view(1, H, 1).expand(B, -1, Q).contiguous()

    final_map = torch.zeros_like(pooled_score, dtype=torch.bool)
    # Blocks that failed the similarity test are always computed.
    final_map[~sim_kblocks] = 1
    final_map[~sim_qblocks] = 1
    # Add the num_to_select best-ranked key blocks per row (at least one).
    final_map = fill_block_map_triton(final_map, num_to_select, sorted_score.indices)
    if is_causal:
        final_map = final_map * causal_mask[None, None, ...]

    if attention_sink:
        final_map[:, :, :, 0] = 1
    return final_map


def get_block_map_meansim(q, k, is_causal=False, BLKQ=128, BLKK=64, simthreshd1=0.1, cdfthreshd=0.9, topk=None, is_sparse=True, return_lut=False, return_block_offset=False, attention_sink=False):
    """
    Predict which (query-block, key-block) pairs of attention to compute.

    q and k are mean-pooled per block (BLKQ / BLKK rows). Key blocks are
    ranked per query block by softmax(pooled_q @ pooled_k^T / sqrt(D)).
    Either the cumulative-probability threshold ``cdfthreshd`` or the
    fraction ``topk`` (exactly one must be set, each scalar or per-head)
    decides how many top-ranked key blocks to keep, with at least one per
    row. Blocks whose mean intra-block cosine similarity does not exceed
    ``simthreshd1`` are always kept. ``is_causal`` applies a block-causal mask
    afterwards, and ``attention_sink`` forces key block 0 on.

    ``is_sparse`` is unused and kept for API compatibility.

    Returns:
        The boolean (B, H, nq, nk) block map by default. With
        ``return_lut=True`` it returns ``(lut, valid_block_num)`` from
        ``block_map_lut_triton``. With ``return_block_offset=True`` it returns
        ``(block_offset, block_count)`` from
        ``block_map_to_block_offset_triton``.
    """
    simthreshd1, cdfthreshd, topk = _check_block_map_args(
        q, simthreshd1, cdfthreshd, topk, return_lut, return_block_offset
    )
    pooled_qblocks, sim_qblocks = get_pool_sim_triton_simmean(q, BLKQ, simthreshd1)
    pooled_kblocks, sim_kblocks = get_pool_sim_triton_simmean(k, BLKK, simthreshd1)
    final_map = _select_block_map(
        pooled_qblocks, sim_qblocks, pooled_kblocks, sim_kblocks, q.shape[-1],
        is_causal, BLKQ, BLKK, cdfthreshd, topk, attention_sink,
    )

    if return_lut:
        lut, valid_block_num = block_map_lut_triton(final_map)
        return lut, valid_block_num
    elif return_block_offset:
        block_offset, block_count = block_map_to_block_offset_triton(final_map)
        return block_offset, block_count
    else:
        return final_map


def get_block_map_meansim_fuse_quant(q, k, km=None, is_causal=False, BLKQ=128, BLKK=64, simthreshd1=0.1, cdfthreshd=0.9, topk=None, is_sparse=True, return_lut=False, return_block_offset=False, attention_sink=False):
    """
    Same block selection as ``get_block_map_meansim``, fused with int8 quantization.

    q is quantized in blocks of BLKQ rows. k is quantized in blocks of BLKK
    rows, after subtracting ``km`` when it is given; ``km`` is also
    subtracted before k is pooled.

    ``is_sparse`` is unused and kept for API compatibility.

    Returns:
        The selection output of ``get_block_map_meansim`` (block map, LUT pair
        or offset pair), followed by ``q_int8, q_scale, k_int8, k_scale``.
    """
    simthreshd1, cdfthreshd, topk = _check_block_map_args(
        q, simthreshd1, cdfthreshd, topk, return_lut, return_block_offset
    )
    pooled_qblocks, sim_qblocks, q_int8, q_scale = get_pool_sim_triton_simmean_fuse_quant(q, None, BLKQ, simthreshd1)
    pooled_kblocks, sim_kblocks, k_int8, k_scale = get_pool_sim_triton_simmean_fuse_quant(k, km, BLKK, simthreshd1)
    final_map = _select_block_map(
        pooled_qblocks, sim_qblocks, pooled_kblocks, sim_kblocks, q.shape[-1],
        is_causal, BLKQ, BLKK, cdfthreshd, topk, attention_sink,
    )

    if return_lut:
        lut, valid_block_num = block_map_lut_triton(final_map)
        return lut, valid_block_num, q_int8, q_scale, k_int8, k_scale
    elif return_block_offset:
        block_offset, block_count = block_map_to_block_offset_triton(final_map)
        return block_offset, block_count, q_int8, q_scale, k_int8, k_scale
    else:
        return final_map, q_int8, q_scale, k_int8, k_scale


def block_map_to_mask(block_map, BLKQ=128, BLKK=64):
    """
    Expand a block-level map into an element-level boolean mask.

    Each entry block_map[..., i, j] is replicated over a BLKQ x BLKK tile, so
    an input of shape (..., x, y) yields a bool tensor of shape
    (..., x * BLKQ, y * BLKK). Nonzero entries become True.

    The expansion uses two vectorized ``repeat_interleave`` calls instead of
    x * y Python-level slice assignments.
    """
    # Cast first so the repeats operate on the smallest dtype.
    return (
        block_map.to(torch.bool)
        .repeat_interleave(BLKQ, dim=-2)
        .repeat_interleave(BLKK, dim=-1)
    )

def block_map_lut(block_map):
    """
    Encode a boolean block map as delta offsets (the LUT used by SpargeAttn).

    Args:
        block_map: (B, H, x, y) boolean tensor; True marks a selected block.

    Returns:
        lut: (B, H, x, y) int32. Each row starts with the 0-based index of its
            first selected block, followed by the gaps between consecutive
            selected indices. Past the last selected block, the row holds
            9999999 minus the last selected index (or 9999999 if nothing is
            selected), followed by zeros.
        valid_entry_num: (B, H, x) int32 number of selected blocks per row.
    """
    selected_count = block_map.to(torch.int32).sum(dim=-1)

    B, H, x, y = block_map.shape

    # 0-based column index of every entry; unselected entries get a sentinel
    # that sorts after every real index.
    col_idx = torch.arange(y, dtype=torch.int64, device=block_map.device).repeat(B, H, x, 1)
    col_idx = col_idx.masked_fill(~block_map, 10000000 - 1)
    ordered = torch.sort(col_idx, dim=-1).values

    # Keep the first index absolute; turn the rest into gaps.
    deltas = torch.cat((ordered[..., :1], ordered[..., 1:] - ordered[..., :-1]), dim=-1)

    return deltas.to(torch.int32), selected_count.to(torch.int32)


def block_map_to_block_offset(block_map):
    """
    Convert block_map (boolean mask) directly to block_offset format (absolute indices).
    This is more efficient than block_map -> lut -> block_offset for FA2 sparse_attn_func.

    Args:
        block_map: (B, H, x, y) boolean tensor, True indicates selected blocks

    Returns:
        block_offset: (B, H, x, y) int32 tensor. Each row starts with the
            ascending indices of its selected blocks; the remaining slots
            hold 10000000.
        block_count: (B, H, x) int32 tensor - number of selected blocks per row

    Example:
        >>> block_map = torch.tensor([[[[0, 0, 1, 0, 0, 1, 0, 1, 0, 1]]]], dtype=torch.bool)
        >>> # Selected positions: [2, 5, 7, 9]
        >>> block_offset, block_count = block_map_to_block_offset(block_map)
        >>> print(block_offset)  # [[[[2, 5, 7, 9, large_num, ...]]]]
        >>> print(block_count)   # [[[4]]]
    """
    block_count = block_map.to(torch.int32).sum(dim=-1)

    B, H, x, y = block_map.shape

    # Column index of every entry; unselected entries get a sentinel that
    # sorts after every real index, so sorting compacts selected indices to
    # the front in ascending order.
    positions = torch.arange(y, dtype=torch.int64, device=block_map.device).repeat(B, H, x, 1)
    positions = positions.masked_fill(~block_map, 10000000)
    block_offset = positions.sort(dim=-1).values

    return block_offset.to(torch.int32), block_count.to(torch.int32)

@triton.jit
def compress_kernel(
    X, XM,
    L: tl.constexpr,
    D: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """
    Mean-pool X along its length axis in blocks of BLOCK_L rows.

    Grid: (ceil(L / BLOCK_L), B * H). Program (idx_l, idx_bh) averages rows
    [idx_l * BLOCK_L, min((idx_l + 1) * BLOCK_L, L)) of the idx_bh-th (L, D)
    slice of X, indexed as contiguous. It writes the resulting D-vector to
    XM[idx_bh, idx_l], cast to XM's element type.
    """
    idx_l = tl.program_id(0)
    idx_bh = tl.program_id(1)

    offs_l = idx_l * BLOCK_L + tl.arange(0, BLOCK_L)
    offs_d = tl.arange(0, D)

    x_offset = idx_bh * L * D
    xm_offset = idx_bh * ((L + BLOCK_L - 1) // BLOCK_L) * D
    # other=0.0: a masked load without `other` leaves rows past L undefined,
    # and they would then be summed into the last block's mean.
    x = tl.load(X + x_offset + offs_l[:, None] * D + offs_d[None, :], mask=offs_l[:, None] < L, other=0.0)

    # Number of valid rows in this block (< BLOCK_L only for the last block).
    nx = min(BLOCK_L, L - idx_l * BLOCK_L)
    x_mean = tl.sum(x, axis=0).to(tl.float32) / nx
    tl.store(XM + xm_offset + idx_l * D + offs_d, x_mean.to(XM.dtype.element_ty))


def mean_pool(x, BLK):
    """
    Average ``x`` over consecutive blocks of ``BLK`` rows along its length axis.

    Args:
        x: contiguous tensor unpacked as (B, H, L, D).
        BLK: rows per block; the last block may be shorter.

    Returns:
        (B, H, ceil(L / BLK), D) tensor with x's dtype and device.
    """
    assert x.is_contiguous()

    batch, heads, seq_len, head_dim = x.shape
    num_blocks = (seq_len + BLK - 1) // BLK
    pooled = x.new_empty((batch, heads, num_blocks, head_dim))

    # One program per (block, batch*head) pair.
    compress_kernel[(num_blocks, batch * heads)](x, pooled, seq_len, head_dim, BLK)
    return pooled


def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):
    """
    Select the top-scoring key blocks for every query block.

    q and k are mean-pooled in blocks of BLKQ / BLKK rows, with k centred by
    its mean over the length axis first. For each query block, the
    ``topk`` key blocks with the largest pooled dot product are kept.

    Args:
        q, k: contiguous tensors laid out as (B, H, L, D), as ``mean_pool`` requires.
        topk_ratio: fraction of key blocks to keep per query block.
        BLKQ, BLKK: query / key block sizes.

    Returns:
        sparse_map: int8 (B, H, nq, nk) map with 1 for selected blocks.
        lut: int64 (B, H, nq, topk) indices of the selected key blocks (unsorted).
        topk: number of key blocks kept per query block. It is at least 1
            whenever there is a key block.
    """
    arg_k = k - torch.mean(k, dim=-2, keepdim=True) # smooth-k technique in SageAttention
    pooled_qblocks = mean_pool(q, BLKQ)
    pooled_kblocks = mean_pool(arg_k, BLKK)
    pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)

    K = pooled_score.shape[-1]
    # Keep at least one key block per query block, mirroring the guard in
    # triton_fill_block_map_kernel. int(topk_ratio * K) can truncate to 0,
    # which would leave every query block with no key block at all.
    topk = min(K, max(1, int(topk_ratio * K)))
    lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices

    sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
    sparse_map.scatter_(-1, lut, 1)
    return sparse_map, lut, topk