- 06 Dec, 2025 1 commit
-
-
Kshitij Lakhani authored
* Add generic stripe_height support for load balancing Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix imports in test for deprecated jax.experimental.pjit Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add test case for stripe_height greater than 1. Add stripe_height arg to reordering methods Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add Striped 1 and 4 test cases. Refactor the Load Balancing test case. Fix the incorrect shape in striping inverser reordering Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Modify test code for CP + AG + THD + stripe height greater than 1 Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add stripe_height arg to fused attn and fused attn fwd API. Add appropriate mask checks for AG+THD+CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD +AG + Striping. Add a method to reorder and all gather segment ids and offsets for kv Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway testing commit Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add comments in primitive registration process Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway test commit Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Undoing incorrect rebase/merge leftovers Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway test commits Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add support for calculating q and kv seqlens and offsets per rank for CP+THD+AG+SW+Striped>1 primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Augment jax primitive register code comments Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Fix the array sizes and padding values returned for seqlens and offsets to fit what the fused attn primitive non cp computation Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add support in new primitive for softmax_offset related changes. Put in missing primitive registering line in again. Increase the seqoffsets arrays lengths by 1 Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Add new set of helper functions for seqlens and seqoffsets fo AG+THD+CP+Stripe>1 which accounts for batching and seq offsets size b+1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add backward primitive for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Modify tests for backward primitive for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Move stripe_height along with other static args in fused_attn_bwd rule. Fix typo in CP+AG+TH+Striped>1 primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Code clean up: remove older version for calculating seqlens and offsets for CP+AG+THD+striped>1 primitive Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add test for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix missing var Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add SWA tests for AG+Striped>1+CP+THD+SWA Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Restoring test code Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove assert preventing SWA code path in CP+AG+Striped primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Parametrize num_segments_per_seq in tests Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean up test code Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Clean up test code in TE common Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Clean up debug statements Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Rename stripe_height to stripe_size Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Code clean up and add additional comments Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> nit: Apply suggestions from code review Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix type on fused attn tests Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Fix seqoffsets length to be passed onto FusedAttn primitive as it is b and not b+1 needed by cuDNN Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove commented code Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix linting issues Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Fix incorrect greptile change Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip THD test cases for CP + AG + Dual chunk. Skip BSHD cases for CP + AG + Striped>1. Correct the layout and shapr parameters passed to the tests Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Pass stripe_size explicitly for ring attn tests for THD cases Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove TODO Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Explicitly fail if THD + AG is being used with a non padding causal mask Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Correct the ID for the test dist fused attn tests to account for cp*2 which is done under the hood Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Set num_segments_per_seq defaults to None instead of 0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Augment comments. Add ValueError for stripe_size=0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Test only 1 num_segments_per_seq combination for CP+AG+THD+Striped>1+SWA instead of 2. Modify the num segments and window size to easily to debug values Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Default stripe_size to None instead of 0. Modify stripe_size check for <=0 instead of ==0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove incorrectly added file Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Explicitly pass zero sized arrays for seg ids and pos in the CP + AG + Striped primitive rather than using the seqlens or the offsets as placeholders Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix linting errors Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add a deep dive doc for CP+THD+AG+Stripe>1+SWA regarding design considerations and decisions Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Put docs and pngs into it's separate dir Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Replace png screenshots with markdown coe blocks for the attention patterns. Remove unecessary pngs Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Add doc file to index.rst. Fix grammatical errors Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> Co-authored-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Co-authored-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 21 Nov, 2025 1 commit
-
-
Kshitij Lakhani authored
* Remove unnecessary SWA calculation from _segment_ids_pos_to_seqlens_offsets Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 18 Nov, 2025 1 commit
-
-
Paweł Gadziński authored
* fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed packed versions Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * jax Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fixes Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix: Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fixes Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fixes Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * sofmtax_fusion -> softmax_fusion_type Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by:
Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 14 Oct, 2025 1 commit
-
-
Kshitij Lakhani authored
* Add BRCM support when creating a test mask for fused attn Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add support for BRCM to correctly generate the mask needed for calculating the seqlens and offsets for THD Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Skip drop=0 and no_bias case for BRCM as cuDNN does not suport this Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Skip BRCM test cases where max_seqlen_q > max_seqlen_kv Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Refactor the segment id run length code for BRCM seqoffset and seqlens calculations Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix the drop inequality skip condition in fused attn Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * nit: Adjust the BRCM id name in the test to make it consistent Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix the brcm mask condition. Fix the condition for cross atnn type pattern to only apply for brcm Change the num segments per sequence to 3 instead of 2 Reduce one test pattern data size and make it such that it triggers brcm Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix lint errors Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix incorrectly changed dtype to numpy bool_ rather than native python bool Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Restore the numsegments to earlier value Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add example for THD BRCM Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 30 Jul, 2025 1 commit
-
-
jberchtold-nvidia authored
* TE primitive checkpointing policies Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove batched gemm policy Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 13 Jun, 2025 2 commits
-
-
Kshitij Lakhani authored
* Add support for Fused Attn MLA head_dim_qk != head_dim_v Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v Modify FusedAttnHelper to accept different head_dims for qk and v and modify assert dims checks in parse_qkv_aval() Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v Modify DotProductAttention call() to extract head dims separately for qk and v Modify the FusedAttn Tests to accommodate for API changes in FusedAttn API Add test case for head_dim_qk != head_dim_v (failing) Modify the baseline JAX appropriately to reshape the output vector based on v dims and not q dims Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix context dims in general DPA in test_fused_attn Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Fix dim for output tensor by replacing with v head dim rather than q head dim Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Modify the fused attn jax unit test case for head dim qk != head dim v Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests Code clean up Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Fix usage of is_fused_attn signature in distributed tests Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Remove unnecessary assert Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
Charlene Yang authored
* add support for head dim > 128 Signed-off-by:
Charlene Yang <charleney@nvidia.com> * remove debugging Signed-off-by:
Charlene Yang <charleney@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * raise tols slightly to tolerate 1/2048 mismatches Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix is_training for test_te_layer Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add bprop support for blackwell Signed-off-by:
Charlene Yang <charleney@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor tweak for format Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix backend selection results Signed-off-by:
Charlene Yang <charleney@nvidia.com> * bump sm100 to sm100+ Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add sq=1 test for MLA Signed-off-by:
Charlene Yang <charleney@nvidia.com> * enable sq=1 for bprop Signed-off-by:
Charlene Yang <charleney@nvidia.com> * minor tweak in comments Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix head_dim logic and remove pytest skip Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add FE fix for d>128 Signed-off-by:
Charlene Yang <charleney@nvidia.com> * update FE again to take in small fixes Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add cuDNN version info in L0 tests Signed-off-by:
Charlene Yang <charleney@nvidia.com> * increase tols for Unfused + large dim Signed-off-by:
Charlene Yang <charleney@nvidia.com> * Revert "add cuDNN version info in L0 tests" This reverts commit 3e1b426ca5319a2c0540b9e73bba7047d0e583e5. Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix tols for Unfused Signed-off-by:
Charlene Yang <charleney@nvidia.com> --------- Signed-off-by:
Charlene Yang <charleney@nvidia.com> Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 31 Mar, 2025 1 commit
-
-
Michael Goldfarb authored
Add fast path for causal masking with segment IDs. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 05 Mar, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
* Fix wheel install after src install Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix JAX imports Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * switch order of dirs for finding so Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Use existing dir src build Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 03 Mar, 2025 1 commit
-
-
Reese Wang authored
* Support THD + ring attention for self attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidate reorder strategy Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix dataclass frozen issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine P2P helper check_supported Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add segment_ids/pos check Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fixup Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dual chunk swap example Signed-off-by:
Reese Wang <rewang@nvidia.com> * Align different reorder code structure Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 14 Feb, 2025 1 commit
-
-
Reese Wang authored
Fix issues when mask/sequence_descriptor is None Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Jan, 2025 1 commit
-
-
Reese Wang authored
* POC for segment_ids/segment_pos Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change segment_pos position Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use RemainingArgs to solve number of parameters mismatches Signed-off-by:
Reese Wang <rewang@nvidia.com> * Test mask_descriptor for accomendating different mask representations Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use descriptor in bwd Signed-off-by:
Reese Wang <rewang@nvidia.com> * Primitives only accepts pure jnp array Signed-off-by:
Reese Wang <rewang@nvidia.com> * segment_ids/pos support POC Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move seqlens/offsets generation to mask descriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename MaskDescriptor to SequenceDescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize get_seqlens_and_offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Utilize sequence desc on FA bwd Signed-off-by:
Reese Wang <rewang@nvidia.com> * Migrate to new API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add docstrings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove small inputs and test different input format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix seed shardings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Optimize sequence converting overhead Signed-off-by:
Reese Wang <rewang@nvidia.com> * Optimize seq_offsets calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix up Signed-off-by:
Reese Wang <rewang@nvidia.com> * fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflicts Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove reduntant line Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Jan, 2025 1 commit
-
-
Reese Wang authored
* Fix SWA mask for THD and forcing seqlen_kv >= seqlen_q for SWA Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize sliding window mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix pylint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 17 Dec, 2024 1 commit
-
-
Reese Wang authored
* Add util functions to attn_mask_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add util functions to qkv_layout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix THD cross reference code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove explicit segment_pad, encoding it to segment_ids Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add jax.jit, replace _token with segment_ids, rename bias shape enum Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add comment for make_mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add doc strings for the added functions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cache for fa deterministic which causes UT failed Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename fixture to avoid conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 11 Nov, 2024 1 commit
-
-
Ming-Xu Huang authored
* Implement ring attention primative for Jax. Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 04 Nov, 2024 1 commit
-
-
Md Fahim Faysal Khan authored
Exposed context parallel params to DPA api Signed-off-by:
Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com> Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> --------- Signed-off-by:
Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com> Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Co-authored-by:
Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 24 Oct, 2024 1 commit
-
-
Michael Goldfarb authored
[JAX] Fix correctness of JAX fused attention with CP and improve numerics check in unit tests (#1282) Fix correctness of JAX fused attention with CP. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 15 Oct, 2024 1 commit
-
-
Michael Goldfarb authored
Update test to check support for context parallel attention. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 10 Oct, 2024 1 commit
-
-
Hua Huang authored
* Expose JAX sliding window attn API Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * No SWA in context parallel; fix RNG seed in test Signed-off-by:
Hua Huang <huah@nvidia.com> * Handle SAW API discrepancy in cuDNN and Python Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add SAW API for flax, all tests passed Will update tests/jax/test_praxis_layers.py next Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_praxis_layers.py for SWA, test passed Signed-off-by:
Hua Huang <huah@nvidia.com> * Use tuple window_size; update for PR #1212 Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add and adjust some pytest.skip Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revised following Reese Wang's comments Still need further debugging: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-NO_BIAS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-NO_BIAS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError: These errors does not exist in the previous commit Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix no-SWA test case errors in previous commit Signed-off-by:
Hua Huang <huah@nvidia.com> * Add Padding mask w/ sliding windows sanity tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use float32 for the reference code softmax calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Reese Wang <rewang@nvidia.com>
-
- 17 Sep, 2024 1 commit
-
-
Michael Goldfarb authored
Implementation of context parallel fused attention using all-gather. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 03 Jul, 2024 1 commit
-
-
Reese Wang authored
* Integrate experimental ragged offset Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use per sequence based offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove v/o_seq_offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16 sanity tests and remove forward tests from the automatically run tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance input checks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Separate fused attn to 2 differnt APIs and add the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add experimental to the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add runtime segments check Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove finished TODO Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* Splitted cpp_extensions.py, renamed mlp.py and fused_attn.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixed import in tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
Cleanup Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 22 Mar, 2024 1 commit
-
-
Reese Wang authored
* Remove unused headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the fused attn workspace size cpp code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the skipped cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename self/cross attention to qkvpacked/kvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update attention mask docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the attn mask implementations Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 06 Mar, 2024 1 commit
-
-
George Karpenkov authored
Bias and seed can both be None, type checking is failed otherwise. Signed-off-by:George Karpenkov <george@metaworld.me>
-
- 22 Feb, 2024 1 commit
-
-
Reese Wang authored
* Refine MHA API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reuse func from the flax Signed-off-by:
Reese Wang <rewang@nvidia.com> * DPA draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * qkv packed draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix test_layer with fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_bias_type and enhance a few code flow Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move scale_factor from __call__ to init Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add DPA public API and tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add qkv separate fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply BSHD_BSHD_BSHD format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove debug log Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attention layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add NVTE_FUSED_ATTN docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fine-grained fused attn settings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the default value of num_attetnion_head and head_dim Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add teardown for fused attn env Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the Optional notation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix Pre/Post scale bias comments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add no_mask tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add checkpoint_name for fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the fused attn batcher Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 16 Jan, 2024 1 commit
-
-
zlsh80826 authored
* Support num_gqa_groups arguments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA support on the JAX bridge code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the kv stride of the arbitrary backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Complete rewrite fused attention tests and add GQA coverage Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support unfused GQA Signed-off-by:
Reese Wang <rewang@nvidia.com> * Calculate seqlen before the primitive for the better perf Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for te_jax Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add num_gqa_groups doc Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the qkv_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Correct the variable naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle Max512 CAUSAL Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add WAR for the latest jax image Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 07 Dec, 2023 1 commit
-
-
cyanguwa authored
* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/paddle for unit tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/pytorch lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify stride generation Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix and/or logic in get_backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix flag_max512 and test_numerics Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove v.contiguous() since get_qkv_layout covers it Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip fp8 tests for sm89 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert mask type to comma-separated list Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last two commits Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * integrate v1/pre-release-5 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup prerelease5 integration and fix FA2.1 commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout to 0 if not training Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix Jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * testing bias/alibi and padding+causal; add alibi to unfused DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * set flag_arb to false when non determinism is not allowed Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * followup on prev commit; remove redundant python env var setting Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: minor tweaks for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prepare for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix determinism logic for fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add bias to bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix gpt_checkpointing/dpa_accuracy problem Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix some seg fault issues Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add failure notes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove use of non-deter var for backend selection Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for lint and CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix workspace size in bwd and uncomment bias test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_alibi and remove check_support Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update tests status Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove workspace_opt from FADescriptor_v1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable arbitrary backend + post scale bias in Jax; waiting on PR 525 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up bhsd order Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove support for padding_causal + cross for max512 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change alibi bias to float32 for bias_1_4/5 tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further clean up tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd fwd output shape for FlashAttention and add backend info for DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix definition of workspace limit when dbias is present Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further tweak DP_WORKSPACE_LIMIT definition Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disallow alibi+no_mask for sdpa flash and update alibi tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable dbias for non-hopper archs Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix layernorm lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remode unused arg for lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove build dir in setup.py Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change selection logic to prefer fused attn on sm90 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix distributed jax test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix h and s order in header Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to cudnn fe v1 branch Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove manual setting of workopt path due to dbias after v1 update Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix paddle CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add post_scale_bias and alibi to sdpa flash support matrix Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix support matrix in header files Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move headers back to .cu and change seed/offset to int64 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update Megatron commit in L1 test and remove all prints in fused attn test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix L1 Megatron test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 arg in L1 Megatron script Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * print only when debug flag is on Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove checkpointing loading to avoid loading other tests results Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com>
-
- 04 Dec, 2023 1 commit
-
-
zlsh80826 authored
Add checkpoint_name Signed-off-by:Reese Wang <rewang@nvidia.com>
-
- 01 Dec, 2023 1 commit
-
-
zlsh80826 authored
* Add rng_state output for cross fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state and output for the flash attention backward Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias for the jax cross attn API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a minor bug Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias in the backward for the arbitrary fused attn backend Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 20 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Remove assertion for NO_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix JAX distributed unit tests name Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 13 Nov, 2023 1 commit
-
-
zlsh80826 authored
[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen Signed-off-by:
Reese Wang <rewang@nvidia.com> * Increase neg infinity abs values Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove unnecessary code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support variable sequence length after cuDNN 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use unique_ptr instead of shared_ptr Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support flash padding mask after 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the fused attn support lists Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove padding_aware from the caching Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix libtransformer.so issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the pad ratio tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a bug with cuDNN 8.9.5 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Release backend resource after the module level unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean the jax live arrays before running the unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix too-few-public-methods lint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Deprecate QKV_INTERLEAVED use in JAX Signed-off-by:
Reese Wang <rewang@nvidia.com> * Deprecate QKV_INTERLEAVED use in Paddle Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance qkv enum mappings Signed-off-by:
rewang <rewang@nvidia.com> * Fix LD_LIBRARY_PATH issue Signed-off-by:
rewang <rewang@nvidia.com> * Arbitrary seqlen kernels only support self attention currently Signed-off-by:
rewang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
rewang <rewang@nvidia.com>
-
- 20 Oct, 2023 1 commit
-
-
Alp Dener authored
fixed incorrect of extend_fsdp_sharding_meta() in cross_fused_attn() Signed-off-by:Alp Dener <adener@nvidia.com>
-
- 25 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fused attention kernel only supports sm80 and sm90 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/csrc/modules.cpp Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * arbitary fused kernel supports sm86/sm89 after 8.9.3 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip sm70 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Forward is_fused_attn_kernel_available to cpp backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cpp is_fused_attn_available API Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 09 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Initially commit for FSDP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding support to fsdp xmap sharding Signed-off-by:
Ming Huang <mingh@nvidia.com> * Specify WeightHParamsCollection of fp8 meta. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support partial FP8 custom calls with FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding amax reduction on the fsdp mesh dim. Signed-off-by:
Ming Huang <mingh@nvidia.com> * clean code Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support FSDP in fMHA. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix missing all-reduce of wgrads along FSDP axis. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Change default value of fsdp_axis_name to for aligning with others Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix RuntimeError: with_sharding_constraint requires a non-empty Signed-off-by:
Ming Huang <mingh@nvidia.com> * Slightly changes (review feedback) Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removed unnecessary comments Signed-off-by:
Ming Huang <mingh@nvidia.com> * Mergeing input_dp_dim into weight_fsdp_dim_map Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update transformer_engine/jax/sharding.py Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-