- 27 Sep, 2025 1 commit
-
-
Phuong Nguyen authored
* init cgemm + unit tests * UB bootstrap with NCCL, no MPI dependency * add NVLINK-P2P check + error message * skip tests if no NVLINK available * use std::vector to store ncclComm_t * update misuse of TP warning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 14 Jul, 2025 1 commit
-
-
Alp Dener authored
* added XLA FFI custom op for TE/common nvte_cublas_gemm Signed-off-by:
Alp Dener <adener@nvidia.com> started GemmPrimitive, abstract done Signed-off-by:
Alp Dener <adener@nvidia.com> gemm custom op working with BF16, needs testing for FP8/MXFP8 Signed-off-by:
Alp Dener <adener@nvidia.com> converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call Signed-off-by:
Alp Dener <adener@nvidia.com> BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue Signed-off-by:
Alp Dener <adener@nvidia.com> fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2 Signed-off-by:
Alp Dener <adener@nvidia.com> updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet Signed-off-by:
Alp Dener <adener@nvidia.com> new GemmPrimitive passing all Dense tests Signed-off-by:
Alp Dener <adener@nvidia.com> import cleanup and reverted code chunk movement Signed-off-by:
Alp Dener <adener@nvidia.com> removed unused .transpose() implementations from ScaledTensors Signed-off-by:
Alp Dener <adener@nvidia.com> all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl Signed-off-by:
Alp Dener <adener@nvidia.com> removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions Signed-off-by:
Alp Dener <adener@nvidia.com> removed unused changes to ScaledTensor classes and debug prints Signed-off-by:
Alp Dener <adener@nvidia.com> * minor unit test cleanup Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * FP8 tests passing on Blackwell but MXFP8 outputs NaN Signed-off-by:
Alp Dener <adener@nvidia.com> * reverted dense and fuseddense changes, FP8 test passing on Hopper and Blackwell, MXFP8 has issues with E5M2 Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * MXFP8 issue traced to scale factor padding with NaNs instead of zeros Signed-off-by:
Alp Dener <adener@nvidia.com> * padding scale with 2^-127 instead of nans Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix bug on rhs_scale_inv usage Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleanup E8M0 type converter use it in gemm.cpp Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * segfault fixed, passing all unittests on Blackwell Signed-off-by:
Alp Dener <adener@nvidia.com> * fix for fuseddense tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix workspace alignment Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul Signed-off-by:
Alp Dener <adener@nvidia.com> all unit tests passing on H100x8 node Signed-off-by:
Alp Dener <adener@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci linting fixes Signed-off-by:
Alp Dener <adener@nvidia.com> fixed batch dimension numbers Signed-off-by:
Alp Dener <adener@nvidia.com> fixed FP8 scale sharding rule when there are no FP8 scales Signed-off-by:
Alp Dener <adener@nvidia.com> added error message for unsupported Shardy partitioner Signed-off-by:
Alp Dener <adener@nvidia.com> fixed test tolerances for FP8 cases Signed-off-by:
Alp Dener <adener@nvidia.com> fixed shardy test skip cases Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly Signed-off-by:
Alp Dener <adener@nvidia.com> * added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated shardy rules for all custom ops to decouple block scale rules from their tensors Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed linting errors Signed-off-by:
Alp Dener <adener@nvidia.com> * changed unit test use_jax_gemm option to be a context to preserve external custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed typo in test utils Signed-off-by:
Alp Dener <adener@nvidia.com> * added sequence-first input warnings Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed datasets version for JAX examples Signed-off-by:
Alp Dener <adener@nvidia.com> * reverting modification to force_1x_quantization decision Signed-off-by:
Alp Dener <adener@nvidia.com> * corrected gemm function syntax in unit tests Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Jun, 2025 1 commit
-
-
Phuong Nguyen authored
* fixes for jittable grouped_quantize * fixes for jittable grouped_gemm * fix contracting_dim for wgrad gemm * exclude jitted grouped_gemm from the unit test as it does not work cudaGraph --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 06 Jun, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor the multi_stream utils + implement nvte_multi_tensor_quantize in TE/Common * implement GroupedQuantizer and grouped_quantize in jaxx * fix logical_axes_names for transpose tensor in ScaledTensor Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Ming Huang <mingh@nvidia.com>
-
- 05 Jun, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
* Fix NVTE_FRAMEWORK=all Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Workflow tests and fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix jax install Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Update dep Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add numpy Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add dep Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 22 Apr, 2025 1 commit
-
-
jberchtold-nvidia authored
* [JAX-Q] Single GPU current scaling for JAX Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix scale check dtype for MXFP8 scales affecting tests using assert_bitwise_scaled_tensors Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Address comments Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove cast to fp32 for norm primitives now that zero-centered gamma dtype issue is fixed Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix lint issue Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Remove unnecessary cast to fp32 Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 09 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* scaling enum abstract * rm NVTE_ from ScalingMode names * rework scaling mode enum in grouped gemm * fix norm sharding --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 04 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout * add fatten_axis option * added gated act to test encoder * sharding constraint fixes * fix padding when flattening first dim needs to be padded * update test sizes so that padding is tested * rm output sharding as it can be done in the flax module * sharding scale_inv for mxfp8 --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 01 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor + mxfp8 * added grouped gemm * rename linear to dense * added cublas init phase for groupedGemm * relax the tol of test encoder multiprocessing mxfp8 by 0.001 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 14 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* softmax custom calls with correct encapsulates * rm jax deprecated features --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Dec, 2024 1 commit
-
-
Phuong Nguyen authored
* fix ctx.aval_out indexing for workspace * add cudnn init to prepare phase of norm custom calls * add thread_local for norm registry instance --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Nov, 2024 1 commit
-
-
Hua Huang authored
* FFI for all softmax functions Signed-off-by:
Hua Huang <huah@nvidia.com> * FFI for FusedAttnBackward and Dequantize FusedAttnBackward passed all testes in test_fused_attn.py. Dequantize is not used currently; finish it for completeness. Signed-off-by:
Hua Huang <huah@nvidia.com> * Fix FusedAttnBackward FFI pybind & simplify Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert changes to tests/jax/test_fused_attn.py Signed-off-by:
Hua Huang <huah@nvidia.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 07 Nov, 2024 1 commit
-
-
Phuong Nguyen authored
* added prepare phase for the FusedAttnForwardFFI * enabled FusedAttnForwardFFI by default * moved prepare phase into pybind --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Nov, 2024 1 commit
-
-
Hua Huang authored
* FFI for some transpose & activation functions Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 04 Nov, 2024 1 commit
-
-
Hua Huang authored
* Add LayerNormForwardFFI(); add FFI calls in Python Signed-off-by:
Hua Huang <huah@nvidia.com> * Add FFI for RMS norm, all tests passed Signed-off-by:
Hua Huang <huah@nvidia.com> * Simplify layer & RMS norm FFI calls Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify tensor size calculations Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 24 Oct, 2024 1 commit
-
-
Hua Huang authored
[JAX] XLA Custom Calls with FFI for FusedAttnFwd, Quantize, Transpose, ActLuFP8, LayerNormForwardFP8FFI, and LayerNormBackwardFFI (#1263) * Add TransposeFFI, test passed Signed-off-by:
Hua Huang <huah@nvidia.com> * Add ActLuFP8FFI; fix TransposeFFI Signed-off-by:
Hua Huang <huah@nvidia.com> * Add QuantizeFFI Signed-off-by:
Hua Huang <huah@nvidia.com> * Add FusedAttnForwardFFI and some unit tests Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix Signed-off-by:
Hua Huang <huah@nvidia.com> * Add LayerNormForwardFP8FFI & LayerNormBackwardFFI Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revise FusedAttnForwardFFI() Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add FFI_CudaGraph_Traits All tests passed, ready for merge Signed-off-by:
Hua Huang <huah@nvidia.com> * Bug fix for FFI data type mismatch Also add a safeguard on the entrance to FFI function Signed-off-by:
Hua Huang <huah@nvidia.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 17 Sep, 2024 1 commit
-
-
Michael Goldfarb authored
Implementation of context parallel fused attention using all-gather. Signed-off-by:Michael Goldfarb <mgoldfarb@nvidia.com>
-
- 14 Aug, 2024 1 commit
-
-
Phuong Nguyen authored
* implemented custom call with ffi in csrc * moved headers of misc to misc.h, add ffi.h * ActLu and DActLu lowering with ffi_lowering * CastTranspose with ffi_lowering * enabled cudaGraph * added 4d input test case to TestActivationLu * added operand_output_aliases for CastTranspose * added env var NVTE_JAX_WITH_FFI, default value = 1 * replace casting ActivationEnum by taking its value --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Aug, 2024 1 commit
-
-
Reese Wang authored
* Support actlen = 0 after cuDNN 9.3.0 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add runtime_segment < max_segment tests Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 25 Jul, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Specify python version Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add classifiers for python Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add utils to build wheels Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * make wheel scripts Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add aarch Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix paddle wheel Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * PaddlePaddle only builds for x86 Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add optional fwk deps Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Python3.8; catch install error Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [wip] cudnn9 compile with paddle support Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [wip] dont link cudnn Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * dlopen cudnn Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * dynamically load nvrtc Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * remove residual packages; exclude stub from nvrtc .so search Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Exclude builtins from nvrtc .so search Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * properly include files for sdist Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * paddle wheel tie to python version Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix paddle build from src [wip] Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix workflow paddle build Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix paddle Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix paddle Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix lint from pr986 Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add sanity wheel test Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add sanity import to wheel test Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * remove upper limit on paddlepaddle version Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Remove unused imports Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Remove pybind11 dependency Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix cpp tests Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Search .sos in cuda home Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * CLeanup, remove residual code Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 03 Jul, 2024 1 commit
-
-
Reese Wang authored
* Integrate experimental ragged offset Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use per sequence based offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove v/o_seq_offsets Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16 sanity tests and remove forward tests from the automatically run tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance input checks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Separate fused attn to 2 differnt APIs and add the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add experimental to the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add runtime segments check Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove finished TODO Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 08 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* categorized `csrc/modules.cpp` Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * adapted the build tool Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
Cleanup Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 May, 2024 1 commit
-
-
Phuong Nguyen authored
* renamed gelu to act * added relu, srelu, qgelu * fixes initialization for layernorm_fp8_mlp tests * moved activation_fp8 prim into testunit file * Moved NVTE_Activation_Enum to common/.../activation.h --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 03 May, 2024 1 commit
-
-
Phuong Nguyen authored
* templated primitives and respective C++ functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for LayerNormMLP, tests in test_custom_compute all passed Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added default arg for pybind get_workspace_size funcs Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for TestTransFormer with non-gated act tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renamed gelu to act Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * improved enum implementation, avoid using magic numbers Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Exposed C++ ActivationEnum to python side Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Changed error messages Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed conditional check on input shape for dbias_cast_transpose Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed dtype (tol) for bias grad tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes so that layer_norm_fp8_mlp can take bias = None Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Set bias = None in flax modules Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Apr, 2024 2 commits
-
-
Phuong Nguyen authored
* Implemented swiglu and silu Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
Phuong Nguyen authored
* combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleaning and formatting Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renaming based on reviewers suggestions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * implemented partial fused layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * geglu + bias passed tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added partial fused calculation for dbias_1 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up Co-authored-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Co-authored-by:
Alp Dener <adener@nvidia.com>
-
- 22 Mar, 2024 1 commit
-
-
Reese Wang authored
* Remove unused headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the fused attn workspace size cpp code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the skipped cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename self/cross attention to qkvpacked/kvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update attention mask docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the attn mask implementations Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 22 Feb, 2024 1 commit
-
-
Reese Wang authored
* Refine MHA API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reuse func from the flax Signed-off-by:
Reese Wang <rewang@nvidia.com> * DPA draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * qkv packed draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix test_layer with fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_bias_type and enhance a few code flow Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move scale_factor from __call__ to init Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add DPA public API and tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add qkv separate fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply BSHD_BSHD_BSHD format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove debug log Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attention layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add NVTE_FUSED_ATTN docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fine-grained fused attn settings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the default value of num_attetnion_head and head_dim Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add teardown for fused attn env Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the Optional notation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix Pre/Post scale bias comments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add no_mask tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add checkpoint_name for fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the fused attn batcher Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 02 Feb, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding support of sequence parallelism Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding RoPE Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong batch_logical_axes Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rnaming FSDP outer env var Signed-off-by:
Ming Huang <mingh@nvidia.com> * Poring RoPE to Praxis layers. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Porting GeLU + [FP8 Cast]. Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Allowing arbitrary dimension of NVShape for the workspace allocation Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modify with review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed for lint Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Port SP to Praxis Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix an issue when enabling both GQA and RoPE. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update docs Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com>
-
- 29 Jan, 2024 1 commit
-
-
Alp Dener authored
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors. Signed-off-by:
Alp Dener <adener@nvidia.com> * removed unused GEMM C++ API in TE-JAX Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed import order for linting Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed custom op errors due to incorrect static arg nums in JAX jit Signed-off-by:
Alp Dener <adener@nvidia.com> * shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting errors for blank lines Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 13 Nov, 2023 1 commit
-
-
zlsh80826 authored
[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen Signed-off-by:
Reese Wang <rewang@nvidia.com> * Increase neg infinity abs values Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove unnecessary code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support variable sequence length after cuDNN 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use unique_ptr instead of shared_ptr Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support flash padding mask after 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the fused attn support lists Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove padding_aware from the caching Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix libtransformer.so issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the pad ratio tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a bug with cuDNN 8.9.5 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Release backend resource after the module level unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean the jax live arrays before running the unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix too-few-public-methods lint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Deprecate QKV_INTERLEAVED use in JAX Signed-off-by:
Reese Wang <rewang@nvidia.com> * Deprecate QKV_INTERLEAVED use in Paddle Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance qkv enum mappings Signed-off-by:
rewang <rewang@nvidia.com> * Fix LD_LIBRARY_PATH issue Signed-off-by:
rewang <rewang@nvidia.com> * Arbitrary seqlen kernels only support self attention currently Signed-off-by:
rewang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
rewang <rewang@nvidia.com>
-
- 25 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fused attention kernel only supports sm80 and sm90 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/csrc/modules.cpp Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * arbitary fused kernel supports sm86/sm89 after 8.9.3 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip sm70 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Forward is_fused_attn_kernel_available to cpp backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cpp is_fused_attn_available API Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 07 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fix flash attention dropout probability with inference Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add output as the fused attention ctx tensor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state as the fused attention ctx tensors Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention supported lengths to the fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor attention primitive to reuse abstract shaped array Signed-off-by:
Reese Wang <rewang@nvidia.com> * Detect backend type to allocate appropriate ctx size Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip dropout correctness instead of return success Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use cudaMemsetAsync and enhance the error handling Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention kernel elts_per_thread update Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant max 512 suffix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Keep only DType and remove NVTEDType from python Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a float32_attention_logits bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Re-calculate workspace size for self attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance bias/dbias shape guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the seed/rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use jax.core.ShapedArray as jax.abstract_arrays is deprecated Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the unittest docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 12 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
bugfix for softmax lowering Signed-off-by:Ryan Jeng <rjeng@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add mp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * better FP8 checker Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace te.* with te.flax* to remove deprecated warning Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse os.environ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/test_multiprocessing_encoder.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove cuda-python Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * adjust readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix cpp lint fix issue of "Could not find a newline character at the end of the file." Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix AssertionError: 1 GPU per process Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace tfds with datasets The Flax application crash if it use TensorFlow Dataset (tfds) in NVIDIA JAX container. The tfds is very useful for downloading well-knwon dataset (e.g., MNIST, GLUE) and commonly used by TF/JAX community. However, it seems like that it is NOT compatible with NVIDIA TensorFlow in NVIDIA JAX container and somehow affects JAX. It triggers random errors at JAX initialization depending on different versions, and make CI unstable. Thus, this commit replaces tfds with "huggingface datasets" to download needed datasets. See "nvbugs 4039266" for more details. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix input sharding Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually. Using DP=4, TP=2 as an example, the device mesh looks like: mesh.device_ids = [[0, 1], [2, 3], [4, 5], [6, 7]] Assume that the process ID is mapped to GPU ID. The process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. The process 0 and process 1 need to share the same micro-batch in the training step, process 0 and process 2, 4, and 6 have different micro-batch. Thus, `shard_array_wrapper` partitions inputs to 4 parts (and setup needed arguments for jax.make_array_from_single_device_arrays). The process 0 and process 1 take the first quarter, process 2 and process 3 take the second quarter, and so on. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor UT for multiprocess example Use Python `multiprocessing` to test the multiprocessing example, if the system has multiple GPU. 1 GPU per process. Because `jax.distributed.initialize` must be called before any other JAX or Flax API, GPU info cannot be queried by calling jax.local_devices() in TestEncoder. Thus, `unittest_query_gpu()` forks another process to query number of GPUs and FP8 capability. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse arg `--num-gpu` JAX doesn't have an API to setup number of GPU used in SPMD mode. The only way is to use `CUDA_VISIBLE_DEVICES` for now. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix ut Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * simplify the mask setting Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * increase batch-size for multigpu example The batch-size 64 is too small to be partitioned for 8xH100. If batch-size is 64, the GEMM shape is 256x8192x8 per GPU. The 8 is too small for FP8 GEMM kernel, and cuBLASLt will throw "Failed to query heuristics". Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix downloading mnist error To download MNIST via `huggingface datasets`, it requires Pillow. Otherwise, it throws `An error occurred while generating the dataset` Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-