Add FP8 support to CP implementation with KV P2P (#1114)
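For context on the "KV P2P" scheme in the title: with ring-style context parallelism, each rank processes one KV block per step, receiving the next block from its ring neighbor while computing on the current one. A minimal sketch of the resulting block order, using a hypothetical helper name (not the actual Transformer Engine code):

```python
def ring_kv_schedule(rank: int, cp_size: int) -> list[int]:
    """Origin rank of the KV block processed at each of cp_size steps.

    Step 0 uses the local KV block; each later step uses the block
    received via P2P from the previous rank in the ring.
    (Hypothetical sketch, not the PR's implementation.)
    """
    return [(rank - step) % cp_size for step in range(cp_size)]
```

For example, with `cp_size=4`, rank 1 processes the blocks originating on ranks 1, 0, 3, 2; every rank sees every KV block exactly once.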
* add window_size to AttnFuncWithCP
* add seq_offsets_qkvo for cuDNN thd
* add seq_offsets_qkvo to AttnFuncWithCP
* fix seq_offsets calculation of cuDNN thd
* remove a thd assert
* fix bias for thd test
* add thd test for cuDNN FA with CP
* skip GQA/MQA test for cuDNN THD
* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
* fix seq_offsets inputs
* remove two comments
* fix attn_mask_type for cuDNN thd with CP
* fix attn_mask_type check
* fix attn_mask_type for cuDNN FA with thd
* fix a typo
* fix out and dout in bwd
* assert cuDNN+thd does not support attn bias
* check if attn_mask_type has padding
* minor change
* change CP test batch size to 2
* fix code format
* fix two assert messages
* fix assert comment
* fix assert comments
* minor fix
* fix assert comments
* assert SWA+CP cannot work with thd format
* add a new CP function for SWA
* add missing dgrads
* minor change
* add draft fwd function for SWA+CP
* minor change
* enable FlashAttention for SWA+CP
* remove an assert of SWA+CP
* call SWAFuncWithCP for SWA+CP
* typo fix
* use 2hd layout
* [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)
* change qkv_format check
* add a code comment
* tensor shape bug fix
* tensor shape fix
* add function to compute cu_seqlens of a CP rank
* add cu_seqlens and cu_seqlens_padded to context parallelism
* typo fix
* minor change
* fix FlashAttention output sequence length
* fix cu_seqlens_kv_per_step calculation
* zero dQKV for ending padded tokens
* zero dQKV tensors of FlashAttention
* fix softmax_lse correction
* remove padded tokens of KV to save communication
* no need to zero dkv for FlashAttention anymore
* zero out tensors
* remove redundant code
* fix CP unit test
* fix kv shape of CP test with thd format
* update CP unit test
* add simple code framework
* try not to have a separate CP function for SWA
* back up some code changes
* back up code
* clean up fwd implementation of SWAFuncWithCP
* remove redundant code
* code cleaning
* fix assert message
* reduce kv chunk concat overheads
* minor change
* make AttnFuncWithCP and SWAFuncWithCP have the same API
* add a docstring
* preliminary implementation of SWAFuncWithCP forward seems working
* fix output shape of SWAFuncWithCP
* code refactoring for FlashAttention and add a code placeholder for bwd
* use gather_along_first_dim
* finish the preliminary implementation of bwd
* remove redundant code
* fix assert condition
* add draft implementation of SWA+CP with FusedAttention
* fix attention mask type of SWA+CP
* code cleaning
* add qkv_layout
* add missing window_size argument
* typo fix
* fix kv shape of SWA+CP
* bug and typo fixes
* fix dout shape
* add multi-stream in fwd of SWA+CP
* save chunk_ids_to_kv_ag in fwd
* add multi-stream in bwd of SWA+CP
* minor fix to CP stream sync
* rename AttnFuncWithCP
* check if window size is None
* fix docstring of AttnFuncWithCP
* minor fix
* add env var for users to choose KV all-gather or KV P2P
* update CP tests
* fix window size in CP unit test
* fix pytest skip messages
* add cp_comm_type into API
* code cleaning
* add deterministic knob in cuDNN fused attn backend
* pass fp8 and fp8_meta to attn_func_with_cp
* assert only Fused Attn can support FP8+CP
* remove redundant assert
* add a fwd draft implementation of FP8+CP
* save fp8 and fp8_meta
* assert sequence length divisibility requirements
* remove a redundant qkv_layout compute
* change an if condition
* typo fixes
* add support table of context parallelism
* typo and code format fixes
* do not print multiple disabling messages
* bug fix
* fix aux_ctx_tensors of FP8
* bug fix
* fix device in torch.arange and adjust code for the MLA PR
* commit code change for FP8+CP
* commit more code changes for FP8+CP
* commit more fp8 code for FP8+CP
* bug fixes
* bug fix
* cast merged CP results from FP32 to BF16
* typo fix
* minor change
* fix softmax_lse
* fix some bugs of FP8 dkv exchange
* typo fix
* add FP8 unit test
* fix typos and clean up asserts
* fix get_p2p_comm_info
* fix dkv p2p exchange
* minor fix
* change FP8 dkv P2P to A2A
* add FP8+CP unit test
* typo fix
* assert amax reduction is needed for FP8+CP
* remove duplicated code
* destroy process group in CP unit test
* remove interval from fp8_recipe because it has been deprecated
* fix the failed CP test with the latest CI pipeline
* remove redundant f-prefix before string
* change META_O_CP
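Several commits above compute per-rank cu_seqlens for context parallelism. Under the load-balanced causal split, each sequence is divided into 2*cp_size chunks and rank r keeps chunks r and 2*cp_size-1-r, so every rank holds exactly 1/cp_size of each sequence (the divisibility requirement is asserted in this PR). A sketch of that computation, with a hypothetical helper name:

```python
def cu_seqlens_on_cp_rank(cu_seqlens: list[int], cp_size: int) -> list[int]:
    """Per-rank cumulative sequence lengths under load-balanced CP.

    Assumes each sequence length is divisible by 2*cp_size; each rank
    then holds 1/cp_size of every sequence, so per-rank offsets are the
    global offsets divided by cp_size. (Illustrative sketch only.)
    """
    seqlens = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    assert all(s % (2 * cp_size) == 0 for s in seqlens), \
        "sequence lengths must be divisible by 2*cp_size"
    return [c // cp_size for c in cu_seqlens]
```

For example, two sequences of lengths 8 and 16 with cp_size=2 give each rank sequences of lengths 4 and 8.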
---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
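The "fix softmax_lse correction" and "cast merged CP results from FP32 to BF16" commits refer to the standard way of merging partial attention outputs computed over disjoint KV chunks: each partial output is rescaled by its per-row log-sum-exp statistic before accumulation, with the merge done in FP32 and the final cast to BF16 at the end. An illustrative NumPy sketch of the merge step (not the actual implementation):

```python
import numpy as np

def merge_attn_outputs(out1, lse1, out2, lse2):
    """Merge two partial attention results over disjoint KV chunks.

    out1/out2: [num_queries, head_dim] partial outputs.
    lse1/lse2: [num_queries] log-sum-exp of the attention scores each
    partial softmax saw. Rescaling by exp(lse_i - lse) makes the merged
    result equal to attention over the union of the KV chunks.
    """
    lse = np.logaddexp(lse1, lse2)  # merged log-sum-exp statistic
    out = (out1 * np.exp(lse1 - lse)[..., None]
           + out2 * np.exp(lse2 - lse)[..., None])
    return out, lse
```

With one key per chunk, each partial output is just that key's value and the partial lse is that key's score, so the merge reduces to a softmax over both scores.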