• [CK_TILE] fmha forward split-kv + combine kernels (#1338) · 0cb2e06d
    Po Yen Chen authored
    
    
    * FA fwd dropout
    
    * FA bwd
    
    * epilogue reuse
    
    * CMakeLists update
    
    * [CK_TILE] support alibi (#1269)
    
    * add alibi support
    
    * fix code
    
    * update code based on comment
    
    * Support more hdim
    
    * fix fp8 bias
    
    * support seqlen_k=0 case
    
    * remove unused printf
    
    * fix format
    
    ---------
    Co-authored-by: rocking <ChunYu.Lai@amd.com>
    
    * now fwd/bwd can build
    
    * bwd alibi
    
    * add bwd validation stream_config
    
    * update generated filenames
    
    * update bwd kernel launch
    
    * CK_TILE_HOST_DEVICE in philox
    
    * Transpose -> transpose
    
    * format
    
    * format
    
    * format
    
    * Generate the instance for FA required
    
    * format
    
    * fix error in WarpGemm
    
    * Add num_splits option and dummy split-kv api method
    
    * Generate fmha_fwd_splitkv()
    
    * Add SplitKV kernel codegen logics
    
    * Add SplitKV combine kernel codegen logics
    
    * Fix mismatched return type
    
    * Clean-up code
    
    * Replace sentinel value before storing
    
    * Fix wrong layout of LSE/LSEacc/Oacc
    
    * Format codes
    
    * Fix o_acc memory error
    
    * Fix wrong kBlockSize used in policy
    
    * Reduce # of combine kernels
    
    * Fix split-kv combine kernel name
    
    * Fix wrong LDS indexing logics
    
    * Fix wrong loop counter step logic
    
    * Undo vector size changes
    
    * Remove no-longer used field
    
    * Remove inconsistent comment
    
    * Remove debug statements in example
    
    * Remove more debug statements
    
    * Add constness to local variables
    
    * Clean up generate.py
    
    * Fix unstable clang-format comment
    
    * Remove unused include directive
    
    * Use shorter template parameter name
    
    * Enable non-split-kv blobs
    
    * Update license date
    
    * Print num_splits conditionally
    
    * Undo disabling data types
    
    * Remove unnecessary tile size for fp8
    
    * Fix wrong pipeline args for fp8
    
    * Fix example output format
    
    * Remove more debug code in combine pipeline
    
    * Add stride kernel arguments for LSE/O acc workspace
    
    * Re-order split-kv pipeline call operator arguments
    
    * Pass LSE/O strides in kernel argument
    
    * Re-order pipeline call operator arguments
    
    * Use tensor_descriptor to locate LSEacc elements
    
    * Support providing invalid element for tensor view
    
    * Set invalid element value for LSEacc tensor view
    
    * Remove hand-written store_tile() code
    
    * Remove necessary value-overwrite logic
    
    * Add transposed lds descriptor
    
    * Support load_tile() for tile_window_with_static_lengths<>
    
    * Undo removing necessary value-overwrite logic
    
    * Use read descriptor to locate lds elements
    
    * Simplify pipeline source code
    
    * Add constraint to kMaxSplits
    
    * Default use kMaxSplits=64 in generate.py
    
    * Revert "Add constraint to kMaxSplits"
    
    This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118.
    
    * Revert "Default use kMaxSplits=64 in generate.py"
    
    This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3.
    
    * Decide alignment by the padding parameter
    
    * Remove no-longer used utility functions
    
    * Remove not-working code
    
    * Add comment & remove no-longer used code
    
    * Fix computation errors
    
    * Add heuristic to override num_splits option
    
    * Add constraint to kMaxSplits
    
    * Fix compilation error
    
    * Clean up pipeline code
    
    * Wrap pointer access as lambda function
    
    * Rename confusing methods
    
    * Use kLogMaxSplits as template parameter
    
    * Finish splitkv combine kernel codegen
    
    * Update kMaxSplits limit
    
    * Use smaller kM0 for splitkv combine kernel
    
    * Ignore dropout flag in splitkv pipeline
    
    * Unify flag usage
    
    * Add back flag kStoreLSE
    
    * Merge lambda calls in pipeline
    
    * Fix compilation errors
    
    * Avoid all empty splits
    
    * Always check for empty loop in splitkv pipelines
    
    * Re-order parameters
    
    * Remove redundant p_drop option check
    
    * Add traits/problem for fwd splitkv kernel
    
    * Conditionally enable uneven split boundary checks
    
    * Add comment for the splitkv traits field
    
    * Change even split criteria
    
    * Re-order statements
    
    * Refine occupancy value for hdim=128&256
    
    * Refine occupancy value for hdim=32&64
    
    * Remove redundant kernel argument
    
    * Separate fmha bwd codegen logics
    
    * Separate fmha fwd codegen logics
    
    * Remove redundant direction parameter in fwd&bwd codegen logics
    
    * Support generate multiple APIs for an example
    
    * Make 'api' an alias of the 'direction' option
    
    * Remove choices for the 'direction' option
    
    * Use dictionary to config all the functions
    
    * Move fmha splitkv codegen logics to other file
    
    * Add fwd_splitkv api for tile_example_fmha_fwd
    
    ---------
    
    Co-authored-by: danyao12 <danyao12>
    Co-authored-by: carlushuang <carlus.huang@amd.com>
    Co-authored-by: rocking <ChunYu.Lai@amd.com>
    Co-authored-by: Jing Zhang <jizhan@amd.com>
fmha.hpp 3.2 KB