    [Example] Add efficient attention sink backward implementations and tests (#877) · ec24561a
    Tong WU authored
    * [Example] Add a new example to support attention sink for MHA
    
    - Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
    - Added a reference attention function to validate the implementation against PyTorch.
    - Included argument parsing for command-line execution of the example.
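The sketch below shows what such a reference can look like. It assumes the common attention-sink formulation: one learnable sink logit per head, which enlarges each softmax denominator but contributes no value. It also assumes a BHSD layout and a causal mask with an optional sliding window. The function name `ref_attention_with_sinks` is hypothetical. NumPy stands in for PyTorch only to keep the sketch dependency-light, and this is not the reference shipped in the example.

```python
import numpy as np


def ref_attention_with_sinks(q, k, v, sinks, window_size=None):
    """Causal attention with per-head sink logits (reference sketch).

    q, k, v: [batch, heads, seq, dim]; sinks: [heads].
    The sink adds exp(sink) to each softmax denominator but carries no value.
    """
    _, heads, seq, dim = q.shape
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) / np.sqrt(dim)

    # Causal mask, optionally restricted to the last `window_size` keys.
    qi = np.arange(seq)[:, None]
    kj = np.arange(seq)[None, :]
    visible = kj <= qi
    if window_size is not None:
        visible &= kj > qi - window_size
    scores = np.where(visible, scores, -np.inf)

    sink = sinks.reshape(1, heads, 1, 1)
    # Row max includes the sink; the diagonal is always visible, so it is finite.
    m = np.maximum(scores.max(axis=-1, keepdims=True), sink)
    p = np.exp(scores - m)
    denom = p.sum(axis=-1, keepdims=True) + np.exp(sink - m)
    return np.einsum("bhqk,bhkd->bhqd", p / denom, v)
```

Setting `window_size=1` with sinks of `-inf` collapses each row to the diagonal, so the output equals `v`. That makes a quick sanity check when wiring up a kernel test.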
    
    * [Example] Replace MHA sink forward example with updated implementation
    
    - Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
    - Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
    - Updated argument parsing and reference functions to align with the new implementation.
    
    * Enhance MHA sink example with sliding window support
    
    - Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
    - Implemented assertions to ensure `window_size` is compatible with `block_N`.
    - Updated the main function to include a `tune` option for performance tuning.
    - Introduced a new test file to validate both full attention and sliding window scenarios.
    - Adjusted FLOPS calculation to account for the sliding window configuration.
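Here is a sketch of the two checks above, under stated assumptions. `check_window` reads "compatible" as "`window_size` is a multiple of `block_N`". The FLOP count assumes a causal mask with `seq_q == seq_kv` and counts only visible (query, key) pairs. Both helpers are hypothetical, and the example's actual formula may differ.

```python
def check_window(window_size, block_n):
    # Hypothetical reading of "compatible": the window covers whole KV tiles.
    assert window_size is None or window_size % block_n == 0, (
        f"window_size={window_size} must be a multiple of block_N={block_n}")


def attended_pairs(seq, window_size=None):
    """Number of (query, key) pairs visible under a causal (sliding-window) mask."""
    w = seq if window_size is None else min(window_size, seq)
    # Queries 0..w-1 see i+1 keys; the remaining seq-w queries see w keys each.
    return w * (w + 1) // 2 + (seq - w) * w


def attention_flops(batch, heads, seq, dim, window_size=None):
    # Two matmuls (Q K^T and P V), each 2 * dim FLOPs per visible pair.
    return 4 * batch * heads * dim * attended_pairs(seq, window_size)
```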
    
    * lint
    
    * [Fix] Add an `inf` check to fix the sliding window attention (SWA) bug
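Sliding window attention commonly needs an infinity check for the following reason, which is an assumption about what this fix addresses. With a window, a whole KV tile can be masked for a query row. The online-softmax running max then stays `-inf`, and `exp(-inf - (-inf))` produces `NaN`. The hypothetical `running_max_and_sum` below shows the usual guard: subtract 0 while the max is still `-inf`.

```python
import numpy as np


def running_max_and_sum(score_tiles, guard=True):
    """Online-softmax statistics (max m, sum l) for one query row, tile by tile."""
    m, l = -np.inf, 0.0
    for tile in score_tiles:
        m_new = max(m, float(tile.max()))
        # Guard: if every score so far is masked, subtract 0 instead of -inf,
        # otherwise exp(-inf - (-inf)) = exp(nan) poisons the running sum.
        m_ref = 0.0 if (guard and m_new == -np.inf) else m_new
        l = l * np.exp(m - m_ref) + float(np.exp(tile - m_ref).sum())
        m = m_new
    return m, l
```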
    
    * Migrate to BSHD layout to align with triton baselines
    
    * lint
    
    * fix typo
    
    * Refactor the MHA sink example to take separate `seq_q` and `seq_kv` parameters for the query and key/value sequence lengths.
    
    * Add GQA sink example for optimized attention mechanism & lint fix
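In grouped-query attention (GQA), several query heads share one K/V head. A reference implementation can simply repeat the K/V heads. The sketch below shows the mapping where query head `h` reads K/V head `h // groups`; the helper is hypothetical and assumes a BHSD layout.

```python
import numpy as np


def expand_kv_heads(kv, num_q_heads):
    """Repeat K/V heads so each query head has a matching K/V head (BHSD layout).

    kv: [batch, kv_heads, seq, dim]. Query head h reads K/V head h // groups.
    """
    kv_heads = kv.shape[1]
    assert num_q_heads % kv_heads == 0, "query heads must be a multiple of KV heads"
    groups = num_q_heads // kv_heads
    # np.repeat keeps copies of the same head adjacent: [0, 0, ..., 1, 1, ...].
    return np.repeat(kv, groups, axis=1)
```

A fused kernel would normally index the shared head directly instead of materializing the copies. The repeat is only convenient for validation.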
    
    * fix several typos and bugs
    
    * lint
    
    * fix speed issues of swa
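One likely source of SWA speedups is skipping KV tiles that lie entirely outside the window; this is assumed here rather than taken from the commit. The hypothetical `kv_block_range` computes the half-open range of KV tiles a query tile must visit. It assumes `seq_q == seq_kv` and that query `i` sees keys `j` with `i - window_size < j <= i`.

```python
def kv_block_range(q_start, block_m, block_n, seq, window_size=None):
    """Half-open range of KV tile indices a query tile must visit.

    Assumes seq_q == seq_kv, a causal mask, and an optional sliding window
    where query i sees keys j with i - window_size < j <= i.
    """
    q_end = min(q_start + block_m, seq)
    # Earliest key any query in the tile can see (the tile's first query sees it).
    first_key = 0 if window_size is None else max(0, q_start - window_size + 1)
    last_key = q_end - 1  # causal: the tile's last query sees at most itself
    return first_key // block_n, last_key // block_n + 1
```

Every tile in this range contains at least one visible pair. The loop therefore skips all fully masked tiles and only pays for partial masking at the edges.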
    
    * Add flash attention example with backward pass for BHSD layout and corresponding test cases
    
    * Add backward pass implementation for flash attention with sinks and corresponding test case
    
    * fix lint and typo
    
    * Optimize the calculation of `dsinks`
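The sink gradient has a compact closed form. Let `lse_i` be the log-sum-exp of row `i` including the sink, and let `delta_i = dO_i · O_i`, the row term that FlashAttention's backward pass already precomputes. Because the sink carries no value, `dO_i/dsink = -exp(sink - lse_i) · O_i`. It follows that `dsink = -Σ_i exp(sink - lse_i) · delta_i`, summed over batch and query positions for each head. Reusing `lse` and `delta` this way is one plausible optimization; whether it matches this commit is an assumption. The NumPy sketch below checks the formula against a finite difference for a single head. The helper names are hypothetical.

```python
import numpy as np


def forward_with_sink(q, k, v, sink):
    """Single-head causal attention with a sink logit; returns (out, lse)."""
    seq, dim = q.shape
    scores = q @ k.T / np.sqrt(dim)
    scores = np.where(np.tril(np.ones((seq, seq), dtype=bool)), scores, -np.inf)
    m = np.maximum(scores.max(axis=1), sink)
    denom = np.exp(scores - m[:, None]).sum(axis=1) + np.exp(sink - m)
    lse = m + np.log(denom)                 # log-sum-exp including the sink
    p = np.exp(scores - lse[:, None])       # attention probabilities (no sink column)
    return p @ v, lse


def dsink_from_saved(out, d_out, lse, sink):
    # delta_i = dO_i . O_i, the same row term a FlashAttention backward precomputes.
    delta = (d_out * out).sum(axis=1)
    p_sink = np.exp(sink - lse)             # probability mass given to the sink
    return -(p_sink * delta).sum()
```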
    
    * Add support for swa backward and update examples
    
    * fix previous typos
    
    * Add example for GQA sink backward pass and update tests for both MHA and GQA sinks
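In the GQA backward pass, each shared K/V head receives gradient from every query head in its group. The minimal sketch below assumes a BHSD layout and the head mapping `h // groups`. It sums per-query-head `dK`/`dV` back onto the shared heads, which is exactly the adjoint of the head repeat used in a reference forward pass. The helper name is hypothetical.

```python
import numpy as np


def reduce_kv_grads(d_kv_expanded, kv_heads):
    """Fold per-query-head dK/dV back onto the shared K/V heads (BHSD layout).

    d_kv_expanded: [batch, q_heads, seq, dim], where query head h used K/V head
    h // groups. The gradient of repeating a head is the sum over its copies.
    """
    batch, q_heads, seq, dim = d_kv_expanded.shape
    groups = q_heads // kv_heads
    # Head index h = kv * groups + g, matching np.repeat(..., groups, axis=1).
    return d_kv_expanded.reshape(batch, kv_heads, groups, seq, dim).sum(axis=2)
```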
    
    * fix lint
    
    * fix previous typos
    
    * typo
example_gqa_sink_bwd_bhsd.py 21.4 KB