- 06 Dec, 2025 1 commit
-
-
Kshitij Lakhani authored
* Add generic stripe_height support for load balancing Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix imports in test for deprecated jax.experimental.pjit Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add test case for stripe_height greater than 1. Add stripe_height arg to reordering methods Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add Striped 1 and 4 test cases. Refactor the Load Balancing test case. Fix the incorrect shape in striping inverser reordering Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Modify test code for CP + AG + THD + stripe height greater than 1 Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add stripe_height arg to fused attn and fused attn fwd API. Add appropriate mask checks for AG+THD+CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD +AG + Striping. Add a method to reorder and all gather segment ids and offsets for kv Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway testing commit Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add comments in primitive registration process Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway test commit Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Undoing incorrect rebase/merge leftovers Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * TMP: Throwaway test commits Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Add support for calculating q and kv seqlens and offsets per rank for CP+THD+AG+SW+Striped>1 primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Augment jax primitive register code comments Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Fix the array sizes and padding values returned for seqlens and offsets to fit what the fused attn primitive non cp computation Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add support in new primitive for softmax_offset related changes. Put in missing primitive registering line in again. Increase the seqoffsets arrays lengths by 1 Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Add new set of helper functions for seqlens and seqoffsets fo AG+THD+CP+Stripe>1 which accounts for batching and seq offsets size b+1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add backward primitive for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Modify tests for backward primitive for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Move stripe_height along with other static args in fused_attn_bwd rule. Fix typo in CP+AG+TH+Striped>1 primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Code clean up: remove older version for calculating seqlens and offsets for CP+AG+THD+striped>1 primitive Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add test for CP+THD+AG+Striped>1 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix missing var Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add SWA tests for AG+Striped>1+CP+THD+SWA Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Restoring test code Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove assert preventing SWA code path in CP+AG+Striped primitive Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Parametrize num_segments_per_seq in tests Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean up test code Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Clean up test code in TE common Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Clean up debug statements Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Rename stripe_height to stripe_size Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Code clean up and add additional comments Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> nit: Apply suggestions from code review Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix type on fused attn tests Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Fix seqoffsets length to be passed onto FusedAttn primitive as it is b and not b+1 needed by cuDNN Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove commented code Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix linting issues Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Fix incorrect greptile change Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip THD test cases for CP + AG + Dual chunk. Skip BSHD cases for CP + AG + Striped>1. Correct the layout and shapr parameters passed to the tests Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Pass stripe_size explicitly for ring attn tests for THD cases Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Remove TODO Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * Explicitly fail if THD + AG is being used with a non padding causal mask Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Correct the ID for the test dist fused attn tests to account for cp*2 which is done under the hood Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Set num_segments_per_seq defaults to None instead of 0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Augment comments. Add ValueError for stripe_size=0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Test only 1 num_segments_per_seq combination for CP+AG+THD+Striped>1+SWA instead of 2. Modify the num segments and window size to easily to debug values Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Default stripe_size to None instead of 0. Modify stripe_size check for <=0 instead of ==0 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove incorrectly added file Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Explicitly pass zero sized arrays for seg ids and pos in the CP + AG + Striped primitive rather than using the seqlens or the offsets as placeholders Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix linting errors Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add a deep dive doc for CP+THD+AG+Stripe>1+SWA regarding design considerations and decisions Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Put docs and pngs into it's separate dir Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Replace png screenshots with markdown coe blocks for the attention patterns. Remove unecessary pngs Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Add doc file to index.rst. Fix grammatical errors Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> Co-authored-by:
Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Co-authored-by:
Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 14 Nov, 2025 1 commit
-
-
jberchtold-nvidia authored
* Refactor to avoid storing a global quantization config so direct recipe passing works as intended Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix use_split_accumulator for current scaling recipe Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix tests that pass direct recipe and were missing quantize meta set Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Revert "fix use_split_accumulator for current scaling recipe" This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900. Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix ci failures Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix amax_history post_init Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * Update transformer_engine/jax/quantize/quantizer.py Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci failures Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix ci issue Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * address comments Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Signed-off-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 23 Sep, 2025 1 commit
-
-
Ming-Xu Huang authored
* Adding Amax Primitive and related args. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Enable local-amax for current-scaling and optionally run AR aross FSDP/TP/SP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc for Amax Primitive. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the function name conflict. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modification as feedback suggested. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix errors from lint. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong amax-scope in the bwd. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Added more description for amax-scope Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong attribute name. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Keep dim for AmaxCalcuation. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove keepDim and add shardy_rule Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix shardy_rule Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove extra-collective bytes from ref_coll_count due to local amax. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 22 Sep, 2025 1 commit
-
-
Phuong Nguyen authored
* remove import jax.extend.ffi Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 09 Sep, 2025 1 commit
-
-
Phuong Nguyen authored
* add swizzle in jax Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added outer_impl Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up FFI Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 27 Aug, 2025 1 commit
-
-
jberchtold-nvidia authored
* Decouple recipe and scaling mode Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Expose global QuantizeConfig instance as a getter Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Format and lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Rename UsageType to TensorSource Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update test_layer.py Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Signed-off-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
-
- 08 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* enabled TE GEMM for all recipes Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add warnings Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 06 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
Revert "[JAX] Disable TE Norm Custom Calls (#1993)" This reverts commit 6c970612 . --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 05 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
Disable Norm custom calls Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Jul, 2025 1 commit
-
-
Phuong Nguyen authored
* add manage_primitives() helper * disable GEMM primitives for non-MXFP8 recipes * implement the NVTE_JAX_CUSTOM_CALLS + deprecate NVTE_JAX_CUSTOM_CALLS_RE * replace NVTE_JAX_CUSTOM_CALLS_RE with NVTE_JAX_CUSTOM_CALLS in TE tests and examples * fix use_jax_gemm contextmanager Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 22 May, 2025 1 commit
-
-
jberchtold-nvidia authored
Make primitive names more granular for better disabling granularity Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com>
-
- 14 Apr, 2025 1 commit
-
-
Johannes Reifferscheid authored
* Add experimental Shardy support. Production use is not yet recommended. --------- Signed-off-by:Johannes Reifferscheid <jreiffers@nvidia.com>
-
- 01 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor + mxfp8 * added grouped gemm * rename linear to dense * added cublas init phase for groupedGemm * relax the tol of test encoder multiprocessing mxfp8 by 0.001 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 14 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* softmax custom calls with correct encapsulates * rm jax deprecated features --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 17 Jul, 2024 1 commit
-
-
Reese Wang authored
* Add enabled() to BasePrimitive * Add layernorm/rmsnorm fallback * Add cast_fp8 fallback * Add transpose/cast_transpose XLA fall back * Act_lu fallback * Add transpose fallback * Add softmax fallback * Unify the use of _cast_fp8 * Add tests for NVTE_JAX_CUSTOM_CALLS_RE --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* Splitted cpp_extensions.py, renamed mlp.py and fused_attn.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixed import in tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-