    Support AMD ROCm on FlashAttention 2 (#1010) · d8f104e9
    rocking authored
    
    
    * Support CK in fmha
    
    * Add CK submodule
    
    * Do not return lse if return_softmax == false
    
    * Use receipt to speed up CK compile time
    
    * Integrate new version of ck_tile
    
    * Support dropout for mha_fwd()
    
    * Add dropout to mha_varlen_fwd()
    
    * Update CK to develop
    
    * Extract padding function for dropout randval
    
    * Extract randval transformation function
    
    * Sync the code structure and coding style with FA
    
    * Remove this line; the C++ API will handle this.
    Sync with test_flash_attn.py
    
    * Fix compile error
    
    * Add mha_bwd
    
    * Generate dropout seed and offset from user generator
    
    * Update CK
    
    * Add mha_varlen_bwd
    
    * Use the same Python interpreter that builds flash-attn to generate the CK kernels
    
    * Fix a bug in group-mode fwd when returning softmax LSE
    
    * Enlarge the test tolerance
    
    * Add test_flash_attn_output() and test_flash_attn_varlen_output()
    
    * Always fill softmax_lse
    
    * Remove duplicate benchmark script, since mha_bwd is now implemented
    
    * Refine getting values from tuples
    
    * Use default parameter for stream_config
    
    * Unblock all platforms
    
    * Add comment
    
    * Refine the test code
    
    * Refine naming
    
    * Add unpack to namespace
    
    * Do not hardcode the warp size 64
    
    * Add more targets
    
    * Add README
    
    * Optimize mha_fwd if seqlen_q == 1
    
    * Support get_wheel_url for ROCm
    
    * Detect the ROCm environment via PyTorch's IS_HIP_EXTENSION
    
    * Update to latest CK
    
    * Add necessary compile flag
    
    * Sync the API with upstream FA
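The bullet "Optimize mha_fwd if seqlen_q == 1" does not say which optimization is used. Upstream FlashAttention's CUDA `mha_fwd` special-cases single-token decoding with MQA/GQA under certain conditions (for example, no dropout and no sliding window). It views the `ngroups = num_heads / num_heads_k` query heads that share one K/V head as `ngroups` query rows of that K/V head. This lets a K/V tile be loaded once and reused by every query head in the group. The sketch below assumes the ROCm path relies on a similar reshape. It is pure Python, not the CK kernel, and `attention` and `swap_seqlen_q_and_groups` are hypothetical helpers. It checks that the swap is exact for an unmasked decode step.

```python
import math


def attention(q, k, v, num_heads, num_heads_k):
    """Naive single-sequence softmax attention: no mask, no dropout.

    q: [seqlen_q][num_heads][d]; k, v: [seqlen_k][num_heads_k][d].
    Query head h reads K/V head h // (num_heads // num_heads_k) (MQA/GQA sharing).
    """
    ngroups = num_heads // num_heads_k
    d = len(q[0][0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q_row in q:
        row = []
        for h in range(num_heads):
            hk = h // ngroups
            # Scaled dot products of this query head against every key of its K/V head.
            scores = [scale * sum(a * b for a, b in zip(q_row[h], k_row[hk])) for k_row in k]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            row.append([sum(w * v_row[hk][c] for w, v_row in zip(weights, v)) / total
                        for c in range(d)])
        out.append(row)
    return out


def swap_seqlen_q_and_groups(q, num_heads_k):
    """Hypothetical helper: turn a seqlen_q == 1 query [1][num_heads_k * ngroups][d]
    into [ngroups][num_heads_k][d]; head hk * ngroups + g becomes row g of head hk."""
    (heads,) = q  # exactly one query row is expected
    ngroups = len(heads) // num_heads_k
    return [[heads[hk * ngroups + g] for hk in range(num_heads_k)] for g in range(ngroups)]


# Toy decode step: 2 K/V heads, 3 query heads per K/V head, head dim 4, 5 cached keys.
num_heads_k, ngroups, d, seqlen_k = 2, 3, 4, 5
num_heads = num_heads_k * ngroups
q = [[[math.sin(0.7 * (h * d + c)) for c in range(d)] for h in range(num_heads)]]
k = [[[math.cos(0.3 * (j * 8 + hk * d + c)) for c in range(d)] for hk in range(num_heads_k)]
     for j in range(seqlen_k)]
v = [[[math.sin(0.5 * (j * 8 + hk * d + c) + 1.0) for c in range(d)] for hk in range(num_heads_k)]
     for j in range(seqlen_k)]

ref = attention(q, k, v, num_heads, num_heads_k)  # 1 row x num_heads heads (GQA)
swapped = attention(swap_seqlen_q_and_groups(q, num_heads_k), k, v,
                    num_heads_k, num_heads_k)  # ngroups rows x num_heads_k heads (plain MHA)

# Map swapped outputs back to the original head order and compare.
max_err = max(abs(ref[0][hk * ngroups + g][c] - swapped[g][hk][c])
              for hk in range(num_heads_k) for g in range(ngroups) for c in range(d))
print(f"max abs difference: {max_err:.2e}")
```

The sketch copies the query for readability. Upstream performs the swap as a reshape/transpose view of `q` and reshapes the output back afterwards. The sketch also applies no mask. With a single query row, a bottom-right-aligned causal mask hides no keys, so upstream disables causal masking before swapping. Otherwise the `ngroups` rows would be masked as if they were different positions.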
    
    ---------
    Co-authored-by: carlushuang <carlus.huang@amd.com>
    Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
    Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
    Co-authored-by: Yichen Yan <oraluben@outlook.com>
flash_api.cpp 5.66 KB