- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* Splitted cpp_extensions.py, renamed mlp.py and fused_attn.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixed import in tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 12 Jun, 2024 1 commit
-
-
Ming-Xu Huang authored
* Reformat FP8 Meta 1. Reformat FP8 meta to be one-set-per-tensor. 2. Remove fp8_max and scale_inv. 3. Remove unused functions in fp8.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove ShardingType and MajorShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix lint errors Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed unittests. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rename few variables. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Add jit to update_amax_list Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed naming error in LayernormMLP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed bugs in test_distributed_layernorm_mlp.py Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 10 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
- Made order of gated act consistent in all branches Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 22 May, 2024 1 commit
-
-
Ming-Xu Huang authored
* Fixed the shape mismatching issue in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Add a corresponding test Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
-
- 13 May, 2024 1 commit
-
-
Phuong Nguyen authored
* renamed gelu to act * added relu, srelu, qgelu * fixes initialization for layernorm_fp8_mlp tests * moved activation_fp8 prim into testunit file * Moved NVTE_Activation_Enum to common/.../activation.h --------- Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com>
-
- 03 May, 2024 1 commit
-
-
Phuong Nguyen authored
* templated primitives and respective C++ functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for LayerNormMLP, tests in test_custom_compute all passed Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added default arg for pybind get_workspace_size funcs Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes for TestTransFormer with non-gated act tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renamed gelu to act Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * improved enum implementation, avoid using magic numbers Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Exposed C++ ActivationEnum to python side Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Changed error messages Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed conditional check on input shape for dbias_cast_transpose Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * changed dtype (tol) for bias grad tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes so that layer_norm_fp8_mlp can take bias = None Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Set bias = None in flax modules Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 02 May, 2024 1 commit
-
-
Reese Wang authored
* Add layernorm_fp8_dot unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the softmax primitives support conditions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add tests for the softmax primitives Signed-off-by:
Reese Wang <rewang@nvidia.com> * Round1 refactor of test_layer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Split dropout arguments of ref code and add hidden/intermediate dropout elementwise comparison Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dropout_braodcast_dim, self_attn_mask tests and clean a few code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Abstract test layer and fix a rope reference code diff Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add bias tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add epsilon and float32 tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add relpos_bias and attention dropout tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Loose the atol Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move common fixtures to conftest.py Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add doc string for test_layer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add doc string for test_layer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflicts of test_layer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Avoid to left bias parameters in graph when use_bias=False Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 24 Apr, 2024 2 commits
-
-
Phuong Nguyen authored
* Implemented swiglu and silu Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
Phuong Nguyen authored
* combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleaning and formatting Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * renaming based on reviewers suggestions Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * implemented partial fused layernorm Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * geglu + bias passed tests Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * added partial fused calculation for dbias_1 Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean up Co-authored-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Signed-off-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Co-authored-by:
Alp Dener <adener@nvidia.com>
-
- 16 Apr, 2024 1 commit
-
-
Ming-Xu Huang authored
-
- 02 Feb, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding support of sequence parallelism Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding RoPE Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong batch_logical_axes Signed-off-by:
Ming Huang <mingh@nvidia.com> * Rnaming FSDP outer env var Signed-off-by:
Ming Huang <mingh@nvidia.com> * Poring RoPE to Praxis layers. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Porting GeLU + [FP8 Cast]. Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Allowing arbitrary dimension of NVShape for the workspace allocation Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modify with review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed for lint Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Port SP to Praxis Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix an issue when enabling both GQA and RoPE. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update docs Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com>
-
- 29 Jan, 2024 1 commit
-
-
Alp Dener authored
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors. Signed-off-by:
Alp Dener <adener@nvidia.com> * removed unused GEMM C++ API in TE-JAX Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed import order for linting Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed custom op errors due to incorrect static arg nums in JAX jit Signed-off-by:
Alp Dener <adener@nvidia.com> * shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting errors for blank lines Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
-
- 12 Jan, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding Cast custom call Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying cast to the kernel of layernorm_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying native cast to the kernel of fp8_dot. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply Cast and native cast to layernorm_geglu_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modifiied code with the review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding 2xACC control to FP8 GEMMs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set precision as an static arg Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 04 Dec, 2023 1 commit
-
-
zlsh80826 authored
Add checkpoint_name Signed-off-by:Reese Wang <rewang@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 13 Nov, 2023 1 commit
-
-
zlsh80826 authored
[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen Signed-off-by:
Reese Wang <rewang@nvidia.com> * Increase neg infinity abs values Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove unnecessary code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support variable sequence length after cuDNN 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use unique_ptr instead of shared_ptr Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support flash padding mask after 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the fused attn support lists Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove padding_aware from the caching Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix libtransformer.so issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the pad ratio tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a bug with cuDNN 8.9.5 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Release backend resource after the module level unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean the jax live arrays before running the unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix too-few-public-methods lint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 06 Oct, 2023 1 commit
-
-
Ming-Xu Huang authored
* [JAX] Enhance Dropout in TransformerLayer. 1. Fixed missing setup of dropout RNG key in TransformerLayer and LayerNormMLP. 2. Allowing seperated dropout rate for FC1's output and other hiddens. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong fp8 scale in _update_fp8_metas_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
Ming-Xu Huang authored
[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196) Fixed missing axes and wrong shape of bias in LayerNormMLP Signed-off-by:Ming Huang <mingh@nvidia.com>
-
- 28 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adjust Module Structure. 1. Collect Flax related modules to a sub-folder, flax. 2. Add a function to unify scale_init for zero-centered-gamma LN. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make changes be compatible to previous versions. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adapt jax/examples to the new module structure. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update jax/docs and Add deprecated warning. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update README Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated_wrapper Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated warning to flax modules which imported via transformer_engine.jax Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix CI errors and update docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removing unnecessary deprecated warning in docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Implementing __iter__ to DeprecatedEnum. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 20 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Allow update_collections and update_fp8_metas to return both Dict and FrozenDict. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong shape issue of bias when fused QKV or KV. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Reuse tuplized features for bias creating. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Replace get_args to be more readable. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 13 Apr, 2023 1 commit
-
-
zlsh80826 authored
* Add zero_center_gamma/functional pass Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma for fp8_ln_mlp Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to modules Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to TransformerLayer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactored code style for improved readability and consistency Signed-off-by:
Reese Wang <rewang@nvidia.com> * Docs enhancement for zero_centered_gamma Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add escape for line break and remove some bad if conditions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise scale_init docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 07 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Rename enable_fp8 to is_fp8_enabled. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding an API to get an instance of DelayedScaling which is set via fp8_autocast. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 29 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Support transpose_bs when decoded=True Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix Bugs, 1. Fix missing dropout_dims in LayerNormMLP. 2. Fix broadcast issues in decoded. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong masks in decoded. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed wrong assert condition in TransformerLayer Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix amax is not set as 0 in each step. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Enhance rules conflict checking and docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * fix code formatting. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 28 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* refactor JAX examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix params_axes_pspec Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Add model parallel example and refactor Update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align code and readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update verification Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add mask Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * num_gpu is configurable Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * solvepylint issue Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * ignore markdown and txt file from license check Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update README.md Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add flax into requirements.txt Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com>
-
- 14 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Updated TE/JAX docs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding TE/JAX docs' rst files Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set DType as pybind11::module_local() to avoid generic_type errors. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Updating license and exporting more modules Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adopting autoapi and removing enum_tools. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Make jax.rst be style consistent. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from code review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update the description of Softmax Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Removed categories in catalog as PyTorch Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 09 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add transformer module , unittests and examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update tests/jax/test_sharding.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/transformer.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint: disable=line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove pylint: disable=too-many-func-args Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Fix the wrong broadcasting dim to dropout masks when enable transpose_bs. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Enable 2xACC for WGRAD and DGRAD by default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename LayerNormMlpBlock as LayerNormMLP Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor to avoid line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename amax_history_size to amax_history_len Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align dropout mask to TE/PyTorch as default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * enlarge atol for decoder unittests Two decoder unittests can pass in old JAX container(e.g., 23.02) but can't in latest container (devel). 1. The actual(-0.020264) and desired(-0.020386) are very close. 2. The TE kernels are not changed, the diff should come from new codegen behavior of XLA. Thus, it is a common floating-point accumulated error. Enlarge atol to avoid unittest failures. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Adding Amax History Support 1. hide amax update in custom_vjp 2. replace amax indexing with roll(using circular buffer) Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * move kernel_init to __post_init__ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor encoder examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove envvar regarding 2xACC Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove unused import Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-