    [CK_TILE] Add fmha fwd N-Warp S-Shuffle pipeline (fmha fwd splitkv pipeline variant) (#1705) · 37cdbf4f
    Po Yen Chen authored
    
    
    * Add check for zero values
    
    * Add static assertions
    
    * Remove invalid option '-e' in smoke_test.sh
    
    * Use correct path of smoke_test.sh
    
    * Avoid zero-sized shared memory array
    
    * Add warning comment
    
    * Replace expr by integer_divide_ceil() call
    
    * Use more readable constant names
    
    * Write down assumption as static assertion
    
    * Add more diagnostic error messages
    
    * Fix wrong BlockWarps when using default pipeline policy
    
    * Add more static assertions for A LDS desc
    
    * Allow using vector size < 8 for data type fp16/bf16
    
    * Align vector size between DRAM dist & LDS desc
    
    * Remove no-longer used func decl
    
    * Fix wrong displayed pipeline name
    
    * Undo policy template changes for tile_example_gemm_basic
    
    * Add missing space and make error message stand out
    
    * Unify print precision
    
    * Add missing include directive <iomanip>
    
    * Replace constant 64 by get_warp_size() call
    
    * Replace constant 128 by named variable: BankLength
    
    * Add kAMBlock/kBNBlock attributes
    
    * Allow using different A/B warp dist for multiple blocks
    
    * Add helper function to get warp dist encodings
    
    * Add 4x64x4 fp16 warp gemm attribute impl
    
    * Complete the A/B warp dist encoding logic
    
    * Fix wrong thread mapping for C matrix
    
    * Use smaller vector size for small tile
    
    * Add static assert to block unsupported warp gemm impl
    
    * Extract common code out as helper method
    
    * Add 4x64x16 fp16 warp gemm type alias
    
    * Add comment to warn developers
    
    * Undo WarpGemmAtrributeMfma<> changes
    
    * Use more clear static assertion error message
    
    * Add trivial wrapper to get warp dstr encodings
    
    * Only transpose warp gemm result if it's square
    
    * Fix compilation error
    
    * Support multi-block warp gemm (on N direction)
    
    * Remove duplicated code
    
    * Fix output encoding of warp gemm
    
    * Fix wrong shape of WarpGemmAtrributeMfmaIterateK<>
    
    * Remove unused code
    
    * Fix wrong shape of WarpGemmAttributeMfmaImplF16F16F32M4N64K4
    
    * Add type config for bf16_t
    
    * Add 4x64x16 bf16 warp gemm
    
    * Update WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
    
    * Add 64x4x4 fp16/bf16 warp gemm impl
    
    * Add 64x4x16 fp16/bf16 warp gemm
    
    * Add static assertion for better error diagnostic
    
    * Get Q dram dstr directly from block gemm
    
    * Add missing header: fused_moe.hpp
    
    * Allow specifying different warp-gemm for gemm0 & gemm1
    
    * Store P matrix into LDS before gemm1
    
    * Fix inconsistent kernel name
    
    * Remove constraint on gemm0 & gemm1 block warps
    
    * Remove unsupported vector size from checking list
    
    * Allow using 4x64x16 warp gemm for gemm0
    
    * Finish policy customization
    
    * Finish pipeline modification
    
    * Use block warps in codegen
    
    * Fix wrong rank of m_lds_window origin
    
    * Use better distributed tensor
    
    * Make P-store earlier
    
    * Remove duplicated expressions
    
    * Remove unnecessary tile window
    
    * Create new files for new splitkv pipeline
    
    * Separate old/new pipeline codegen logic
    
    * Sync changes from develop
    
    * Undo gemm kernel/pipeline changes
    
    * Undo gemm example changes
    
    * Remove blank lines
    
    * Fix typo
    
    * Use new warp gemm interface
    
    * Fix link error
    
    * Fix wrong pipeline tag
    
    * Fix more link errors
    
    * Avoid unnecessary padding
    
    * Always use vector load for K
    
    * Padding on fastest dimension when necessary
    
    * Force padding Q on hdim_q
    
    * Set high dimension padding flag to false
    
    * Re-format headers
    
    * Use warps=<1, 4, 1> for both gemm0 & gemm1
    
    * Fix compilation errors
    
    * Remove m/l shuffle logic
    
    * Ignore duplicate data when writing lse_acc
    
    * Use gemm0 block warps as lds tile width
    
    * Remove hard-coded numbers
    
    * Fix wrong distribution width
    
    * Remove unnecessary code
    
    * Add s_barrier before writing to LDS
    
    * Store Q into LDS before gemm0
    
    * Fix wrong Q tile size
    
    * Use simple Q lds descriptor for debugging
    
    * Use more realistic Q lds descriptor
    
    * Add comment & use better variable name
    
    * Make Q lds space not overlapped with others
    
    * Remove unnecessary block_tile_reduce_sync() call
    
    * Move Q load statements
    
    * Move block_sync_lds() right before use
    
    * Re-order instructions
    
    * Remove unnecessary lambda expression
    
    * Use 8 threads on kMaxSplits direction while doing reduction
    
    * Tiny correction for using 8 threads on kMaxSplits direction for combine kernel
    
    * Padding num_split direction of o_acc tile window to 4x
    
    * Update splitkv combine pipeline design
    
    * Add kN1 back to splitkv combine pipeline problem
    
    * Fix compilation errors
    
    * Add missing template parameter
    
    * Fix wrong splitkv combine kernel name
    
    * Fix wrong origin
    
    * Fix wrong LDS descriptor shape
    
    * Fix sync & reduction logic
    
    * Remove unnecessary static assertions
    
    * Extract tile size computation logic
    
    * Make sure we can reuse padding flags in combine kernels
    
    * Rename variables
    
    * Use OaccDataType in BlockFmhaSplitKVCombinePipelineTileSizes<>
    
    * Remove unnecessary static assertion
    
    * Fix function name typo
    
    * Add constraint on kN1 template parameter
    
    * Hide K tile loading latency in earlier iteration
    
    * Fix wrong splitkv kernel name
    
    * Use s_shuffling to replace p_shuffling, which removes the need for cross-warp reduction
    
    * Rename pipeline
    
    * Fix wrong pipeline name attribute
    
    * Add GetAlignmentQ() for NWarpSShuffle pipeline
    
    * Separate Q tile into dram tile & register tile concepts
    
    * Remove non-square warp gemm transpose c type alias
    
    * Fall back tile size changes for fmha fwd splitkv
    
    * Remove redundant change
    
    * Refine naming for the S tile
    
    * Use better naming of the S tile dstr (read from lds)
    
    * Share Q lds with K lds
    
    * Tiny change
    
    * Fix by using static_for to pass CI checks
    
    ---------
    Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>