    [CK_TILE] fmha fwd splitkv optimization for decode (seqlen_q=1) (#1789) · 24b12d04
    Po Yen Chen authored
    
    
    * Update license year
    
    * Add initial code to override decode problem
    
    * Fix splitkv traits/args overriding error
    
    * Reshape and transpose lse for decode
    
    * Remove debug code
    
    * Prettify example code
    
    * Use better function name
    
    * Add kMergeNumHeadGroupsSeqLenQ flag
    
    Kernel users can use this switch to turn the optimization on or off for
    some problem sizes
    
    * Add missing flag declarations
    
    * Turn off kMergeNumHeadGroupsSeqLenQ by default in codegen
    
    * Group similar statements together
    
    * Remove assumption of seqlen_q=1
    
    * Remove kMergeNumHeadGroupsSeqLenQ from splitkv combine kernel
    
    * Support kMergeNumHeadGroupsSeqLenQ=true in fmha splitkv kernel
    
    * Run kMergeNumHeadGroupsSeqLenQ=true kernels when needed
    
    * Fix group mode block skip logic
    
    * Undo changes of normal fwd kernel
    
    * Update GridSize() and use it for the splitkv kernel (#1799)
    
    ---------
    Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
fmha.hpp 3.03 KB