import torch
import triton
import triton.language as tl


@triton.jit
def compress_kernel(
    X,
    XM,
    L: tl.constexpr,
    D: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Mean-pool BLOCK_L consecutive rows of X into one row of XM.

    Grid: (ceil(L / BLOCK_L), B * H). Each program averages one block of
    rows for one (batch, head) slice.

    X  : contiguous (B*H, L, D) input, viewed flat.
    XM : contiguous (B*H, ceil(L/BLOCK_L), D) output.

    NOTE(review): tl.arange(0, D) requires D to be a power of two —
    confirm callers guarantee this.
    """
    idx_l = tl.program_id(0)   # which row-block along L
    idx_bh = tl.program_id(1)  # which (batch, head) slice

    offs_l = idx_l * BLOCK_L + tl.arange(0, BLOCK_L)
    offs_d = tl.arange(0, D)

    x_offset = idx_bh * L * D
    xm_offset = idx_bh * ((L + BLOCK_L - 1) // BLOCK_L) * D

    # other=0.0 is required: masked-off lanes are undefined without it,
    # which would corrupt the mean of the final partial block.
    x = tl.load(
        X + x_offset + offs_l[:, None] * D + offs_d[None, :],
        mask=offs_l[:, None] < L,
        other=0.0,
    )

    # Number of valid rows in this block (the last block may be partial).
    # tl.minimum instead of builtin min: robust for constexpr/runtime mix.
    nx = tl.minimum(BLOCK_L, L - idx_l * BLOCK_L)

    # Accumulate in fp32 for accuracy, store back in XM's dtype.
    x_mean = tl.sum(x, axis=0, dtype=tl.float32) / nx
    tl.store(XM + xm_offset + idx_l * D + offs_d, x_mean.to(XM.dtype.element_ty))


def mean_pool(x, BLK):
    """Mean-pool ``x`` along the sequence dimension in blocks of ``BLK``.

    Args:
        x: contiguous tensor of shape (B, H, L, D).
        BLK: block size along L; the last block may be partial.

    Returns:
        Tensor of shape (B, H, ceil(L / BLK), D), same dtype/device as x.
    """
    assert x.is_contiguous()
    B, H, L, D = x.shape
    L_BLOCKS = (L + BLK - 1) // BLK
    x_mean = torch.empty((B, H, L_BLOCKS, D), device=x.device, dtype=x.dtype)
    grid = (L_BLOCKS, B * H)
    compress_kernel[grid](x, x_mean, L, D, BLK)
    return x_mean


def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):
    """Select the top-k key blocks per query block by pooled similarity.

    Args:
        q: query tensor (B, H, Lq, D), contiguous.
        k: key tensor (B, H, Lk, D), contiguous.
        topk_ratio: fraction of key blocks to keep per query block.
        BLKQ, BLKK: pooling block sizes for q and k.

    Returns:
        sparse_map: int8 tensor (B, H, Lq/BLKQ, Lk/BLKK), 1 = selected.
        lut: indices of the selected key blocks (unsorted).
        topk: number of key blocks kept per query block.
    """
    # Subtract the per-head key mean (smooth-k technique in SageAttention)
    # before pooling to sharpen the block-level score.
    arg_k = k - torch.mean(k, dim=-2, keepdim=True)
    pooled_qblocks = mean_pool(q, BLKQ)
    pooled_kblocks = mean_pool(arg_k, BLKK)
    pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)

    K = pooled_score.shape[-1]
    topk = min(K, int(topk_ratio * K))
    lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices

    sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
    sparse_map.scatter_(-1, lut, 1)
    return sparse_map, lut, topk


def get_cuda_arch(device_index):
    """Return the CUDA arch string (e.g. "sm80") for the given device."""
    major, minor = torch.cuda.get_device_capability(device_index)
    return f"sm{major}{minor}"