- 19 Dec, 2025 1 commit
-
-
jberchtold-nvidia authored
* Handle meshs set with jax.set_mesh Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 17 Dec, 2025 1 commit
-
-
jberchtold-nvidia authored
* Tutorial for integration te/jax quantization into an existing framework Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * add todos Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * support nvfp4 sr rng key, move wrapper module into TE itself, fix bfloat16 cast Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * update docstrings Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix QKV proj and out proj in Flax example transformer layer Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Use fused attention in quickstart_jax example Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remat policy Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * add tutorial to docs Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * update title Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * remove unused dtype from TE DPA module Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix notebook title Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add explanation of flax module wrapper Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 25 Nov, 2025 1 commit
-
-
Phuong Nguyen authored
* allow dp + fsdp and fixed sr_rng_state partitioning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup for lint test Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 31 Oct, 2025 1 commit
-
-
jberchtold-nvidia authored
* Fix mesh resource requirement when no mesh Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * do not require meshresource if all axes are manual axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * remove abstract_mesh is None check Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 22 Oct, 2025 1 commit
-
-
jberchtold-nvidia authored
* [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add test for SR state preservation across VJP boundaries Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix sharding of SR rng state Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * update tolerances slightly now that SR is enabled Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Use hashlib for deterministic hashes across runs for SR Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * rename uses_rht on scaled tensors to has_applied_rht Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * add assert Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Move decision of whether to use RHT into helper.py and add dedicated RHT tests Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix use_rht attr usage Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix pure-jax rht usage criteria Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Adjust tolerances after rebase Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 06 Oct, 2025 1 commit
-
-
Phuong Nguyen authored
* not fuse bias for output all reduction case + unit tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * norm to reduce dgamma along tpsp as well Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix test_distributed_layernorm byte counts Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * increase tols for jax_gemm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 02 Oct, 2025 1 commit
-
-
jberchtold-nvidia authored
Fix shard map issue Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 27 Sep, 2025 1 commit
-
-
Phuong Nguyen authored
* init cgemm + unit tests * UB bootstrap with NCCL, no MPI dependency * add NVLINK-P2P check + error message * skip tests if no NVLINK available * use std::vector to store ncclComm_t * update misuse of TP warning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 27 Aug, 2025 2 commits
-
-
Phuong Nguyen authored
* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
jberchtold-nvidia authored
Delay MeshResource validation until first usage Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 26 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* clean up sharding Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added tpsp_resource Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * update tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * rework test for MeshResource Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add mesh_resource into fp8_autocast in test_helper.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 20 Aug, 2025 1 commit
-
-
jberchtold-nvidia authored
[JAX] Error checking for mesh resource and update GemmPrimitive to use global_mesh_resource().fsdp_resource (#2088) * Enforce global MeshResource is set Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Use global_mesh_resource().fsdp_resource in gemm primitive Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update tests Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update gemm.py Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update test_layer.py Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 13 Aug, 2025 2 commits
-
-
Phuong Nguyen authored
* fix pspec check Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add docstring Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * use dict.get() Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
Phuong Nguyen authored
* add manual axis filer to sharding_constraint impl Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * use abstract_mesh instead of physical_mesh Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add a comment Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleanup Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean unused var Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 07 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* rm batch_dim, sequence_dim, sequence_parallel_output Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * rm lhs_quantized_colwise and rhs_quantized_colwise Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * rm unnecessary transpose_batch_sequence arg from some modules Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* remove dot1_output_axes Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Jul, 2025 1 commit
-
-
Alp Dener authored
[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism correctly for sequence-parallel inputs (#1980) * updated GemmPrimitive partitioning rules to explicitly control all-reduce vs. reduce-scatter for sequence-parallelism Signed-off-by:
Alp Dener <adener@nvidia.com> * corrected handling of FSDP sharding for the RHS operand Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use correct logical axes variable to identify sequence-parallel dim in LayerNormDenseGeneral Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting issues Signed-off-by:
Alp Dener <adener@nvidia.com> * added assert on sequence-parallel options when GemmPrimitive is disabled Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Alp Dener <adener@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 16 Jul, 2025 1 commit
-
-
jberchtold-nvidia authored
* Support flax sharding constraints Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add warning for deprecated TE logical axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update examples Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 29 May, 2025 1 commit
-
-
Hua Huang authored
* Support SWA in CP Ring Attn THD striped sharding Signed-off-by:
Hua Huang <huah@nvidia.com> * Add some comments; move check to _FusedAttnCPWithP2PHelper.check_supported() Signed-off-by:
Hua Huang <huah@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Remove unused check Signed-off-by:
Hua Huang <huah@nvidia.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com>
-
- 16 May, 2025 1 commit
-
-
jberchtold-nvidia authored
* [JAX] Update flax module param initialization to support logical partitioning axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix ffn1 intermediate result being replicated Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add documentation and assert when logical_axes=None Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix bias in LayerNormMLP flax module Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix layer tests to not use nn_partitioning and instead use nn.with_logical_axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 04 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout * add fatten_axis option * added gated act to test encoder * sharding constraint fixes * fix padding when flattening first dim needs to be padded * update test sizes so that padding is tested * rm output sharding as it can be done in the flax module * sharding scale_inv for mxfp8 --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 01 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor + mxfp8 * added grouped gemm * rename linear to dense * added cublas init phase for groupedGemm * relax the tol of test encoder multiprocessing mxfp8 by 0.001 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 11 Nov, 2024 1 commit
-
-
Ming-Xu Huang authored
* Implement ring attention primative for Jax. Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Michael Goldfarb <mgoldfarb@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 17 Sep, 2024 1 commit
-
-
Michael Goldfarb authored
Implementation of context parallel fused attention using all-gather. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 19 Aug, 2024 1 commit
-
-
Frédéric Bastien authored
Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 02 Feb, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding support of sequence parallelism Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding RoPE Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong batch_logical_axes Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rnaming FSDP outer env var Signed-off-by:
Ming Huang <mingh@nvidia.com> * Poring RoPE to Praxis layers. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Porting GeLU + [FP8 Cast]. Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Allowing arbitrary dimension of NVShape for the workspace allocation Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modify with review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed for lint Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Port SP to Praxis Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix an issue when enabling both GQA and RoPE. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update docs Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 30 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* [JAX] Fix incorrect sharding when only enable FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Add WAR to memory misaligned issues of LN BWD. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Reuse sm_arch for avoiding duplicate code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Support multiple sizes allocation in WorkspaceManager. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Use template and ariadic arguments to improve multple sizes allocator. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 09 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Initially commit for FSDP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding support to fsdp xmap sharding Signed-off-by:
Ming Huang <mingh@nvidia.com> * Specify WeightHParamsCollection of fp8 meta. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support partial FP8 custom calls with FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding amax reduction on the fsdp mesh dim. Signed-off-by:
Ming Huang <mingh@nvidia.com> * clean code Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support FSDP in fMHA. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix missing all-reduce of wgrads along FSDP axis. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Change default value of fsdp_axis_name to for aligning with others Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix RuntimeError: with_sharding_constraint requires a non-empty Signed-off-by:
Ming Huang <mingh@nvidia.com> * Slightly changes (review feedback) Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removed unnecessary comments Signed-off-by:
Ming Huang <mingh@nvidia.com> * Mergeing input_dp_dim into weight_fsdp_dim_map Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update transformer_engine/jax/sharding.py Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 23 May, 2023 1 commit
-
-
zlsh80826 authored
* Unfused scale+softmax if bias is present Signed-off-by:
Reese Wang <rewang@nvidia.com> * WAR a causal masking + no_bias bug and add the unittests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the optional args (bias) sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Disable fused attn in JAX by default, enable it with NVTE_USE_FUSED_ATTN Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add thread local for the plan cache Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename dbeta to dbias for the readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add scaled softmax with dropout test cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Updated NVTE_FUSED_ATTN variable name Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
zlsh80826 authored
* Add fused attention unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_* enums Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_Mask_Type and remove FMHADescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move common functions to utils Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change namespace to fused_attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_fwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused_attn_max_512_bwd_qkvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_bwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant blank line Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a potential bug for cu_seqlen converter Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reformat fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the unfused attention warning message Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the deprecated header Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix flax import Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attention related mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_mask_type and attn_bias_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor jax primitive API * Merge q_cu_seqlen and kv_cu_seqlen * Remove is_causal_masking * Replace seed with rng_state * Add is_training argument Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove dsoftmax from the customcall Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add None guard for bias and dropout_rng Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add version guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add is_fused_attn_kernel_available() to correctly dispatch the attention impl Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the merge conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Adjust the code style Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add the missing blank lines Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the order of FADescriptor members Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the readability of fused_attn_max_512.cu Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize the input dimension unpacking Signed-off-by:
Reese Wang <rewang@nvidia.com> * 16 bits fused attention requires 8.9.1 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update fused attention support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle None type when sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change to the padding ratio Signed-off-by:
Reese Wang <rewang@nvidia.com> * Performance optimization for non-bias cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert the cudnn-frontend PRIVATE keyword which was used for debugging Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert "Update fused attention support matrix" This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a. Signed-off-by:
Reese Wang <rewang@nvidia.com> * Treat b * s as total_seqs to align ragged cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine test_fused_attn.py * Replace reference code with flax.linen * Remove unnecessary comments * Use AttnMaskType Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the cuDNN compile version Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dropout to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Slightly adjust the headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Typo fix: remove redundant either Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidating fused attention requirements Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace cudnn_frontend::throw_if with NVTE_CHECK for the better error line report Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_fp16_bf16_max_seqlen_512 for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove CUDNN_FRONTEND_UNUSED Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add more annotations to the custom calls Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Updated TE/JAX docs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding TE/JAX docs' rst files Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set DType as pybind11::module_local() to avoid generic_type errors. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Updating license and exporting more modules Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adopting autoapi and removing enum_tools. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Make jax.rst be style consistent. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from code review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update the description of Softmax Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Removed categories in catalog as PyTorch Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 09 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add transformer module , unittests and examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update tests/jax/test_sharding.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/transformer.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint: disable=line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove pylint: disable=too-many-func-args Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Fix the wrong broadcasting dim to dropout masks when enable transpose_bs. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Enable 2xACC for WGRAD and DGRAD by default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename LayerNormMlpBlock as LayerNormMLP Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor to avoid line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename amax_history_size to amax_history_len Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align dropout mask to TE/PyTorch as default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * enlarge atol for decoder unittests Two decoder unittests can pass in old JAX container(e.g., 23.02) but can't in latest container (devel). 1. The actual(-0.020264) and desired(-0.020386) are very close. 2. The TE kernels are not changed, the diff should come from new codegen behavior of XLA. Thus, it is a common floating-point accumulated error. Enlarge atol to avoid unittest failures. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Adding Amax History Support 1. hide amax update in custom_vjp 2. replace amax indexing with roll(using circular buffer) Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * move kernel_init to __post_init__ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor encoder examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove envvar regarding 2xACC Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove unused import Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 24 Feb, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * fix conflict due to PR62 Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix c-extension-no-member and no-name-in-module 1. add transformer_engine_jax into extension-pkg-whitelist 2. convert pylintrc from CRLF to LF format Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update setup.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint:disable and refactor import order Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-