- 13 Nov, 2025 1 commit
-
-
jberchtold-nvidia authored
* Support for checkpointing quantizations Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add jaxpr test for quant checkpoint name Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Revert "Support for checkpointing quantizations" This reverts commit f7b784940369d0da2a77c57fa6ea744e883c5832. Signed-off-by:
JAX Toolbox <jax@nvidia.com> * Checkpoint quantizations Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * revert other files Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * move checkpointing to VJPs Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix ci failure Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Signed-off-by:
JAX Toolbox <jax@nvidia.com> Co-authored-by:
JAX Toolbox <jax@nvidia.com>
-
- 09 Oct, 2025 1 commit
-
-
jberchtold-nvidia authored
Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 07 Oct, 2025 1 commit
-
-
Phuong Nguyen authored
* reuse amax for current scaling Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 03 Oct, 2025 1 commit
-
-
vthumbe1503 authored
Signed-off-by:Varun Thumbe <vthumbe@nvidia.com> *Jax integration for clamped swiglu. This is the continuation of PR which added Clamped Swiglu(used in GPT OSS) support in TE along with Pytorch integration. This PR hooks up the clamped swiglu and dswiglu's nvte APIs to TE Jax.
-
- 27 Sep, 2025 1 commit
-
-
Phuong Nguyen authored
* init cgemm + unit tests * UB bootstrap with NCCL, no MPI dependency * add NVLINK-P2P check + error message * skip tests if no NVLINK available * use std::vector to store ncclComm_t * update misuse of TP warning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 23 Sep, 2025 1 commit
-
-
Ming-Xu Huang authored
* Adding Amax Primitive and related args. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Enable local-amax for current-scaling and optionally run AR aross FSDP/TP/SP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc for Amax Primitive. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the function name conflict. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modification as feedback suggested. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix errors from lint. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong amax-scope in the bwd. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Added more description for amax-scope Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong attribute name. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Keep dim for AmaxCalcuation. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove keepDim and add shardy_rule Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix shardy_rule Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove extra-collective bytes from ref_coll_count due to local amax. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 05 Sep, 2025 1 commit
-
-
jberchtold-nvidia authored
* Custom call tests passing Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix test_layer.py Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix comments Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Cast non-quantized kernels to input dtype in VJPs Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix tests Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 27 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 13 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* fix pspec check Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaning Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add docstring Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * use dict.get() Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 08 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* enabled TE GEMM for all recipes Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add warnings Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 07 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* rm batch_dim, sequence_dim, sequence_parallel_output Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * rm lhs_quantized_colwise and rhs_quantized_colwise Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * rm unnecessary transpose_batch_sequence arg from some modules Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 06 Aug, 2025 1 commit
-
-
Phuong Nguyen authored
* remove dot1_output_axes Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 24 Jul, 2025 1 commit
-
-
Alp Dener authored
[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism correctly for sequence-parallel inputs (#1980) * updated GemmPrimitive partitioning rules to explicitly control all-reduce vs. reduce-scatter for sequence-parallelism Signed-off-by:
Alp Dener <adener@nvidia.com> * corrected handling of FSDP sharding for the RHS operand Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use correct logical axes variable to identify sequence-parallel dim in LayerNormDenseGeneral Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting issues Signed-off-by:
Alp Dener <adener@nvidia.com> * added assert on sequence-parallel options when GemmPrimitive is disabled Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Alp Dener <adener@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 14 Jul, 2025 1 commit
-
-
Alp Dener authored
* added XLA FFI custom op for TE/common nvte_cublas_gemm Signed-off-by:
Alp Dener <adener@nvidia.com> started GemmPrimitive, abstract done Signed-off-by:
Alp Dener <adener@nvidia.com> gemm custom op working with BF16, needs testing for FP8/MXFP8 Signed-off-by:
Alp Dener <adener@nvidia.com> converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call Signed-off-by:
Alp Dener <adener@nvidia.com> BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue Signed-off-by:
Alp Dener <adener@nvidia.com> fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2 Signed-off-by:
Alp Dener <adener@nvidia.com> updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet Signed-off-by:
Alp Dener <adener@nvidia.com> new GemmPrimitive passing all Dense tests Signed-off-by:
Alp Dener <adener@nvidia.com> import cleanup and reverted code chunk movement Signed-off-by:
Alp Dener <adener@nvidia.com> removed unused .transpose() implementations from ScaledTensors Signed-off-by:
Alp Dener <adener@nvidia.com> all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl Signed-off-by:
Alp Dener <adener@nvidia.com> removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions Signed-off-by:
Alp Dener <adener@nvidia.com> removed unused changes to ScaledTensor classes and debug prints Signed-off-by:
Alp Dener <adener@nvidia.com> * minor unit test cleanup Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * FP8 tests passing on Blackwell but MXFP8 outputs NaN Signed-off-by:
Alp Dener <adener@nvidia.com> * reverted dense and fuseddense changes, FP8 test passing on Hopper and Blackwell, MXFP8 has issues with E5M2 Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * MXFP8 issue traced to scale factor padding with NaNs instead of zeros Signed-off-by:
Alp Dener <adener@nvidia.com> * padding scale with 2^-127 instead of nans Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix bug on rhs_scale_inv usage Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleanup E8M0 type converter use it in gemm.cpp Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * segfault fixed, passing all unittests on Blackwell Signed-off-by:
Alp Dener <adener@nvidia.com> * fix for fuseddense tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fix workspace alignment Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul Signed-off-by:
Alp Dener <adener@nvidia.com> all unit tests passing on H100x8 node Signed-off-by:
Alp Dener <adener@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci linting fixes Signed-off-by:
Alp Dener <adener@nvidia.com> fixed batch dimension numbers Signed-off-by:
Alp Dener <adener@nvidia.com> fixed FP8 scale sharding rule when there are no FP8 scales Signed-off-by:
Alp Dener <adener@nvidia.com> added error message for unsupported Shardy partitioner Signed-off-by:
Alp Dener <adener@nvidia.com> fixed test tolerances for FP8 cases Signed-off-by:
Alp Dener <adener@nvidia.com> fixed shardy test skip cases Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly Signed-off-by:
Alp Dener <adener@nvidia.com> * added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated shardy rules for all custom ops to decouple block scale rules from their tensors Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed linting errors Signed-off-by:
Alp Dener <adener@nvidia.com> * changed unit test use_jax_gemm option to be a context to preserve external custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed typo in test utils Signed-off-by:
Alp Dener <adener@nvidia.com> * added sequence-first input warnings Signed-off-by:
Alp Dener <adener@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed datasets version for JAX examples Signed-off-by:
Alp Dener <adener@nvidia.com> * reverting modification to force_1x_quantization decision Signed-off-by:
Alp Dener <adener@nvidia.com> * corrected gemm function syntax in unit tests Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 18 Jun, 2025 1 commit
-
-
Phuong Nguyen authored
* TensorUsage + FP8 GEMM with all layouts handling on BW Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 16 May, 2025 1 commit
-
-
jberchtold-nvidia authored
* [JAX] Update flax module param initialization to support logical partitioning axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix ffn1 intermediate result being replicated Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Lint Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Add documentation and assert when logical_axes=None Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix bias in LayerNormMLP flax module Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Fix layer tests to not use nn_partitioning and instead use nn.with_logical_axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 04 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout * add fatten_axis option * added gated act to test encoder * sharding constraint fixes * fix padding when flattening first dim needs to be padded * update test sizes so that padding is tested * rm output sharding as it can be done in the flax module * sharding scale_inv for mxfp8 --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 01 Apr, 2025 1 commit
-
-
Phuong Nguyen authored
* refactor + mxfp8 * added grouped gemm * rename linear to dense * added cublas init phase for groupedGemm * relax the tol of test encoder multiprocessing mxfp8 by 0.001 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com> Co-authored-by:
Jeremy Berchtold <jberchtold@nvidia.com>
-
- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 06 Nov, 2024 1 commit
-
-
Hua Huang authored
* FFI for some transpose & activation functions Signed-off-by:
Hua Huang <huah@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> --------- Signed-off-by:
Hua Huang <huah@nvidia.com> Signed-off-by:
Hua Huang <huangh1994@outlook.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 17 Jul, 2024 1 commit
-
-
Reese Wang authored
* Add enabled() to BasePrimitive * Add layernorm/rmsnorm fallback * Add cast_fp8 fallback * Add transpose/cast_transpose XLA fall back * Act_lu fallback * Add transpose fallback * Add softmax fallback * Unify the use of _cast_fp8 * Add tests for NVTE_JAX_CUSTOM_CALLS_RE --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* Splitted cpp_extensions.py, renamed mlp.py and fused_attn.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixed import in tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Jun, 2024 1 commit
-
-
Ming-Xu Huang authored
* Reformat FP8 Meta 1. Reformat FP8 meta to be one-set-per-tensor. 2. Remove fp8_max and scale_inv. 3. Remove unused functions in fp8.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove ShardingType and MajorShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix lint errors Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed unittests. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rename few variables. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Add jit to update_amax_list Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed naming error in LayernormMLP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed bugs in test_distributed_layernorm_mlp.py Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 13 May, 2024 1 commit
-
-
Phuong Nguyen authored
* renamed gelu to act * added relu, srelu, qgelu * fixes initialization for layernorm_fp8_mlp tests * moved activation_fp8 prim into testunit file * Moved NVTE_Activation_Enum to common/.../activation.h --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 03 May, 2024 1 commit
-
-
Phuong Nguyen authored
* templated primitives and respective C++ functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for LayerNormMLP, tests in test_custom_compute all passed Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added default arg for pybind get_workspace_size funcs Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for TestTransFormer with non-gated act tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renamed gelu to act Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * improved enum implementation, avoid using magic numbers Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Exposed C++ ActivationEnum to python side Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Changed error messages Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed conditional check on input shape for dbias_cast_transpose Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed dtype (tol) for bias grad tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes so that layer_norm_fp8_mlp can take bias = None Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Set bias = None in flax modules Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 01 May, 2024 1 commit
-
-
Ming-Xu Huang authored
* Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modify with the feedback of code review Signed-off-by:
Ming Huang <mingh@nvidia.com> * Hiding FlaxFloatMeta32 inside fp8.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make functions to be JAX tracable objects. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rebased with mian. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update jax images for github workflow. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 24 Apr, 2024 2 commits
-
-
Phuong Nguyen authored
* Implemented swiglu and silu Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
Phuong Nguyen authored
* combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleaning and formatting Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renaming based on reviewers suggestions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * implemented partial fused layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * geglu + bias passed tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added partial fused calculation for dbias_1 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up Co-authored-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Co-authored-by:
Alp Dener <adener@nvidia.com>
-
- 02 Feb, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding support of sequence parallelism Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding RoPE Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong batch_logical_axes Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rnaming FSDP outer env var Signed-off-by:
Ming Huang <mingh@nvidia.com> * Poring RoPE to Praxis layers. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Porting GeLU + [FP8 Cast]. Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Allowing arbitrary dimension of NVShape for the workspace allocation Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modify with review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed for lint Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Port SP to Praxis Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix an issue when enabling both GQA and RoPE. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update docs Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com>
-
- 29 Jan, 2024 1 commit
-
-
Alp Dener authored
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors. Signed-off-by:
Alp Dener <adener@nvidia.com> * removed unused GEMM C++ API in TE-JAX Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed import order for linting Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed custom op errors due to incorrect static arg nums in JAX jit Signed-off-by:
Alp Dener <adener@nvidia.com> * shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting errors for blank lines Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
-
- 12 Jan, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding Cast custom call Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying cast to the kernel of layernorm_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying native cast to the kernel of fp8_dot. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply Cast and native cast to layernorm_geglu_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modifiied code with the review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding 2xACC control to FP8 GEMMs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set precision as an static arg Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 09 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Initially commit for FSDP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding support to fsdp xmap sharding Signed-off-by:
Ming Huang <mingh@nvidia.com> * Specify WeightHParamsCollection of fp8 meta. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support partial FP8 custom calls with FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding amax reduction on the fsdp mesh dim. Signed-off-by:
Ming Huang <mingh@nvidia.com> * clean code Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support FSDP in fMHA. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix missing all-reduce of wgrads along FSDP axis. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Change default value of fsdp_axis_name to for aligning with others Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix RuntimeError: with_sharding_constraint requires a non-empty Signed-off-by:
Ming Huang <mingh@nvidia.com> * Slightly changes (review feedback) Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removed unnecessary comments Signed-off-by:
Ming Huang <mingh@nvidia.com> * Mergeing input_dp_dim into weight_fsdp_dim_map Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update transformer_engine/jax/sharding.py Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 13 Apr, 2023 1 commit
-
-
zlsh80826 authored
* Add zero_center_gamma/functional pass Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma for fp8_ln_mlp Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to modules Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to TransformerLayer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactored code style for improved readability and consistency Signed-off-by:
Reese Wang <rewang@nvidia.com> * Docs enhancement for zero_centered_gamma Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add escape for line break and remove some bad if conditions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise scale_init docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 10 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
Signed-off-by:Ming-Xu Huang <mingh@nvidia.com>
-
- 09 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add transformer module , unittests and examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update tests/jax/test_sharding.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/transformer.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint: disable=line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove pylint: disable=too-many-func-args Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Fix the wrong broadcasting dim to dropout masks when enable transpose_bs. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Enable 2xACC for WGRAD and DGRAD by default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename LayerNormMlpBlock as LayerNormMLP Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor to avoid line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename amax_history_size to amax_history_len Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align dropout mask to TE/PyTorch as default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * enlarge atol for decoder unittests Two decoder unittests can pass in old JAX container(e.g., 23.02) but can't in latest container (devel). 1. The actual(-0.020264) and desired(-0.020386) are very close. 2. The TE kernels are not changed, the diff should come from new codegen behavior of XLA. Thus, it is a common floating-point accumulated error. Enlarge atol to avoid unittest failures. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Adding Amax History Support 1. hide amax update in custom_vjp 2. replace amax indexing with roll(using circular buffer) Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * move kernel_init to __post_init__ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor encoder examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove envvar regarding 2xACC Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove unused import Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 24 Feb, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * fix conflict due to PR62 Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix c-extension-no-member and no-name-in-module 1. add transformer_engine_jax into extension-pkg-whitelist 2. convert pylintrc from CRLF to LF format Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update setup.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint:disable and refactor import order Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-