[CK_TILE] Add fmha fwd N-Warp S-Shuffle pipeline (fmha fwd splitkv pipeline variant) (#1705)
* Add check for zero values
* Add static assertions
* Remove invalid option '-e' in smoke_test.sh
* Use correct path of smoke_test.sh
* Avoid zero-sized shared memory array
* Add warning comment
* Replace expr by integer_divide_ceil() call
* Use more readable constant names
* Write down assumption as static assertion
* Add more diagnostic error messages
* Fix wrong BlockWarps when using default pipeline policy
* Add more static assertions for A LDS desc
* Allow using vector size < 8 for data type fp16/bf16
* Align vector size between DRAM dist & LDS desc
* Remove no-longer used func decl
* Fix wrong displayed piepline name
* Undo policy template changes for tile_example_gemm_basic
* Add missing space and make error message stands out
* Unify print precision
* Add missing include directive <iomanip>
* Replace constant 64 by get_warp_size() call
* Replace constant 128 by named variable: BankLength
* Add kAMBlock/kBNBlock attributes
* Allow usig different A/B warp dist for multiple blocks
* Add helper function to get warp dist encodings
* Add 4x64x4 fp16 warp gemm attribute impl
* Complete the A/B warp dist encoding logic
* Fix wrong thread mapping for C matrix
* Use smaller vector size for small tile
* Add static assert to block unsupported warp gemm impl
* Extract common code out as helper method
* Add 4x64x16 fp16 warp gemm type alias
* Add comment to warning developers
* Undo WarpGemmAtrributeMfma<> changes
* Use more clear static assertion error message
* Add trivial wrapper to get warp dstr encodings
* Only transpose warp gemm result if it's square
* Fix compilation error
* Support multi-block warp gemm (on N direction)
* Remove duplicated code
* Fix output encoding of warp gemm
* Fix wrong shape of WarpGemmAtrributeMfmaIterateK<>
* Remove unused code
* Fix wrong shape of WarpGemmAttributeMfmaImplF16F16F32M4N64K4
* Add type config for bf16_t
* Add 4x64x16 bf16 warp gemm
* Update WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
* Add 64x4x4 fp16/bf16 warp gemm impl
* Add 64x4x16 fp16/bf16 warp gemm
* Add static assertion for better error diagnostic
* Get Q dram dstr directly form block gemm
* Add missing header: fused_moe.hpp
* Allow specifying different warp-gemm for gemm0 & gemm1
* Store P matrix into LDS before gemm1
* Fix inconsistant kernel name
* Remove constraint on gemm0 & gemm1 block warps
* Remove unsupported vector size from checking list
* Allow using 4x64x16 warp gemm for gemm0
* Finish policy customization
* Finish pipeline modification
F#
* Use block warps in codegen
* Fix wrong rank of m_lds_window origin
* Use better distributed tensor
* Make P-store earlier
* Remove duplicated experssions
* Remove unnecessary tile window
* Create new files for new splitkv pipeline
* Separate old/new pipeline codegen logic
* Sync changes form develop
* Undo gemm kernel/pipeline changes
* Undo gemm example changes
* Remove blank lines
* Fix typo
* Use new warp gemm interface
* Fix link error
* Fix wrong pipeline tag
* Fix more link error
* Avoid unnecessary padding
* Always use vector load for K
* Padding on fastest dimension when necessary
* Force padding Q on hdim_q
* Set high dimension padding flag to false
* Re-format headers
* Use warps=<1, 4, 1> for both gemm0 & gemm1
* Fix complilation errors
* Remove m/l shuffle logics
* Ignore duplicate data when write lse_acc
* Use gemm0 block warps as lds tile width
* Remove hard-coded numbers
* Fix wrong distribution width
* Remove unnecessary code
* Add s_barrier before writing to LDS
* Store Q into LDS before gemm0
* Fix wrong Q tile size
* Use simple Q lds descriptor for debuging
* Use more realistic Q lds descriptor
* Add comment & use better variable name
* Make Q lds space not overlapped with others
* Remove unnecessary block_tile_reduce_sync() call
* Move Q load statements
* Move block_sync_lds() right before use
* Re-order instructions
* Remove necessary lambda expression
* Use 8 threads on kMaxSplits direction while doing reduction
* Tiny correction for using 8 threads on kMaxSplits direction for combine kernel
* Padding num_split direction of o_acc tile window to 4x
* Update splitkv combine pipeline design
* Add kN1 back to splitkv combine pipeline problem
* Fix compilation errors
* Add missing template parameter
* Fix wrong splitkv combine kernel name
* Fix wrong origin
* Fix wrong LDS descriptor shape
* Fix sync & reduction logics
* Remove unnecessary static assertions
* Extract tile size computation logics
* Make sure we can reuse padding flags in combine kernels
* Rename variables
* Use OaccDataType in BlockFmhaSplitKVCombinePipelineTileSizes<>
* Remove unnecessary static assertion
* Fix function name typo
* Add constraint on kN1 template parameter
* Hide K tile loading latency in earlier iteration
* Fix wrong splitkv kernel name
* Use s_shuffling to replace p_shuffling which removes the needs of cross-warp reduction
* Rename pipeline
* Fix wrong pipeline name attribute
* Add GetAlignmentQ() for NWarpSShuffle pipeline
* Separate Q tile into dram tile & register tile concepts
* Remove non-squre warp gemm transpose c type alias
* Fallback tile size changes for fmha fwd splitkv
* Remove redundant change
* Refine naming for the S tile
* Use better naming of the S tile dstr (read from lds)
* Share Q lds with K lds
* Tiny change
* Fix with using static_for for passing CI checking
---------
Co-authored-by:
Qianfeng Zhang <Qianfeng.Zhang@amd.com>
Showing
This diff is collapsed.
Please register or sign in to comment