Integrate cuDNN frontend v1 to fused attention (#497)
* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes Signed-off-by:Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/paddle for unit tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/pytorch lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify stride generation Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix and/or logic in get_backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix flag_max512 and test_numerics Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove v.contiguous() since get_qkv_layout covers it Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip fp8 tests for sm89 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert mask type to comma-separated list Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last two commits Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * integrate v1/pre-release-5 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup prerelease5 integration and fix FA2.1 commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout to 0 if not training Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix Jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * testing bias/alibi and padding+causal; add alibi to unfused DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * set flag_arb to false when non determinism is not allowed Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * followup on prev commit; remove redundant python env var setting Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: minor tweaks for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prepare for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix determinism logic for fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add bias to bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix gpt_checkpointing/dpa_accuracy problem Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix some seg fault issues Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add failure notes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove use of non-deter var for backend selection Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for lint and CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix workspace size in bwd and uncomment bias test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_alibi and remove check_support Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update tests status Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove workspace_opt from FADescriptor_v1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable arbitrary backend + post scale bias in Jax; waiting on PR 525 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up bhsd order Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove support for padding_causal + cross for max512 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change alibi bias to float32 for bias_1_4/5 tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further clean up tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd fwd output shape for FlashAttention and add backend info for DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix definition of workspace limit when dbias is present Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further tweak DP_WORKSPACE_LIMIT definition Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disallow alibi+no_mask for sdpa flash and update alibi tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable dbias for non-hopper archs Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix layernorm lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remode unused arg for lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove build dir in setup.py Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change selection logic to prefer fused attn on sm90 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix distributed jax test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix h and s order in header Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to cudnn fe v1 branch Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove manual setting of workopt path due to dbias after v1 update Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix paddle CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add post_scale_bias and alibi to sdpa flash support matrix Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix support matrix in header files Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move headers back to .cu and change seed/offset to int64 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update Megatron commit in L1 test and remove all prints in fused attn test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix L1 Megatron test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 arg in L1 Megatron script Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * print only when debug flag is on Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove checkpointing loading to avoid loading other tests results Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Showing
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Please register or sign in to comment