# Block Sparse Attention

As prompt lengths continue to increase, the computational and memory bandwidth demands of Large Language Models (LLMs) grow significantly, making efficient processing more challenging. However, by fully exploiting the inherent sparsity of attention patterns, we can substantially reduce the computation required at inference time. This approach not only improves the efficiency of LLMs but also enables them to handle longer and more complex prompts without a proportional increase in resource consumption.

To this end, we introduce Block Sparse Attention, a library of sparse attention kernels that supports various sparsity patterns, including streaming attention with token granularity, streaming attention with block granularity, and block-sparse attention. By incorporating these patterns, Block Sparse Attention can significantly reduce the computational cost of LLMs, thereby enhancing their efficiency and scalability.

We release the implementation of Block Sparse Attention, which is modified based on [FlashAttention](https://github.com/Dao-AILab/flash-attention) 2.4.2.

## News

- [2024/10] We release both the forward and backward passes of Block Sparse Attention.

## Features

Block Sparse Attention supports four patterns:

1. **Dense attention**: computes the full attention matrix.
2. **Streaming attention with token granularity**: computes attention with a fixed number of sink tokens and local tokens. See [StreamingLLM](https://arxiv.org/abs/2309.17453) for details.
3. **Streaming attention with block granularity** (block_size = 128): computes attention with a fixed number of sink blocks and local blocks.
4. **Blocksparse attention** (block_size = 128): takes a block mask and computes attention only over the unmasked blocks.

**Importantly, we support assigning different patterns to different heads.** You can use `head_mask_type` to specify the pattern for each head. It is a list of integers whose length equals the number of query heads. For a given head, `mask_type = 0` means dense attention, `mask_type = -1` means streaming attention (either block streaming or exact streaming), and `mask_type = 1` means blocksparse attention, in which case the head uses `base_blockmask[mask_type - 1]` as its attention mask. For example, if you have 8 heads and

```python
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
```

then heads 0 and 1 use a blocksparse mask, heads 2 to 4 and head 6 use dense attention, and heads 5 and 7 use a streaming mask. (See the usage sketch at the end of this README.)

The interfaces are:

```python
from block_sparse_attn import block_sparse_attn_func

block_sparse_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
)
```

```python
from block_sparse_attn import block_streaming_attn_func

block_streaming_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    max_seqlen_q, max_seqlen_k,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=True,
    return_attn_probs=False,
)
```

```python
from block_sparse_attn import token_streaming_attn_func

# bwd pass is not yet supported
token_streaming_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    max_seqlen_q, max_seqlen_k,
    deterministic=False,
    softmax_scale=None,
    return_attn_probs=False,
)
```

## Performance

### Block Sparse Speedup
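
## Example Usage

The sketch below shows how the pieces above fit together for `block_sparse_attn_func`: building `head_mask_type`, `streaming_info`, and `base_blockmask`, then calling the kernel on unpadded (varlen) tensors. The tensor shapes, dtypes, and the `streaming_info`/`base_blockmask` layouts are assumptions based on the FlashAttention varlen convention, not a specification; consult the repository's tests for the authoritative layouts.

```python
import torch
from block_sparse_attn import block_sparse_attn_func

# Illustrative sizes (assumptions, not requirements)
batch, seqlen, nheads, headdim, block_size = 1, 1024, 8, 64, 128
device, dtype = "cuda", torch.float16

# Unpadded (varlen) layout, as in FlashAttention: (total_tokens, nheads, headdim)
q = torch.randn(batch * seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch * seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch * seqlen, nheads, headdim, device=device, dtype=dtype)

# Cumulative sequence lengths per batch entry: [0, seqlen, 2*seqlen, ...]
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen,
                          device=device, dtype=torch.int32)

# Per-head pattern from the example above:
# heads 0-1 blocksparse, heads 5 and 7 streaming, the rest dense
head_mask_type = torch.tensor([1, 1, 0, 0, 0, -1, 0, -1],
                              device=device, dtype=torch.int32)

# Streaming configuration for the streaming heads
# (assumed layout: a sink count and a local count per head)
streaming_info = torch.tensor([1, 3] * nheads, device=device, dtype=torch.int32)

# Block mask for the blocksparse heads over 128x128 tiles
# (assumed shape: batch x num_blocksparse_heads x q_blocks x k_blocks)
nblocks = seqlen // block_size
base_blockmask = torch.ones(batch, 2, nblocks, nblocks,
                            device=device, dtype=torch.bool)
base_blockmask[:, :, :, nblocks // 2:] = False  # e.g. drop the second half of key blocks

out = block_sparse_attn_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    head_mask_type,
    streaming_info,
    base_blockmask,
    seqlen, seqlen,   # max_seqlen_q, max_seqlen_k
    0.0,              # p_dropout
    deterministic=False,
    softmax_scale=None,
    is_causal=True,
)
```

Dropping the second half of the key blocks is only meant to show how the mask is wired in; in practice `base_blockmask` would come from a sparsity predictor or a fixed pattern chosen for your workload.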


