- 02 Jan, 2025 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 15 Jun, 2024 1 commit
-
-
Phuong Nguyen authored
* rm tensor check if the workspace is empty Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add trust_remote=true for load_dataset() in the mnist test Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-
- 14 Jun, 2024 1 commit
-
-
Kirthi Shankar Sivamani authored
* Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Apply formatting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 19 Apr, 2024 1 commit
-
-
Ming-Xu Huang authored
* Allow multi-dims for dgamma and dbeta in LN descriptor. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the jit error in examples/jax Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 18 Apr, 2024 1 commit
-
-
Alp Dener authored
* fixed static argnums for jax.jit in single gpu encoder test, changed warning filtering for pytest Signed-off-by:
Alp Dener <adener@nvidia.com> * propagating the fix to the JAX mnist example Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed missing space ibetween flags i QAA scripts Signed-off-by:
Alp Dener <adener@nvidia.com> * added TE warnings into the ignore list Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 08 Nov, 2023 1 commit
-
-
Tim Moon authored
* Use FP8 tolerances in JAX FP8 tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Programmatically compute expected floating point error Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Loosen tolerance for MNIST test Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
-
- 03 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Cast Flax collections to FrozenDict as WAR to adapt Flax 0.7.1. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding min version of flax to requirements.txt in examples Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix praxis tests and rename compare_frozen_dict to compare_dict Signed-off-by:
Ming Huang <mingh@nvidia.com> * [Paddle] Refactor FP8 state (#350) Refactor fp8 state Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Store FP8 checkpointing data in CPU (#351) Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make test_layer be able to run on both Flax >=0.7.1 and <=0.7.0 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update Flax version Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Tim Moon <tmoon@nvidia.com> Co-authored-by:
Tian Zheng <tizheng@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <tmoon@nvidia.com>
-
- 16 May, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adding JAX/Praxis modules and dependencies. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding UTs to JAX/Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove praxis as a dependency due to not strictly needed Signed-off-by:
Ming Huang <mingh@nvidia.com> * Repalce is_fp8_supported to is_fp8_available Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make Praxis as an optional dependency. 1. Removed 'from . import praxis' in __init__.py. 1.1 Noted, keep 'from . import flax' for deprecated warning. 2. Changed te.flax to te_flax in examples and README.rst. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding a workaround to FP8 training on Praxis. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add mp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * better FP8 checker Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace te.* with te.flax* to remove deprecated warning Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse os.environ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/test_multiprocessing_encoder.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove cuda-python Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * adjust readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix cpp lint fix issue of "Could not find a newline character at the end of the file." Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix AssertionError: 1 GPU per process Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace tfds with datasets The Flax application crash if it use TensorFlow Dataset (tfds) in NVIDIA JAX container. The tfds is very useful for downloading well-knwon dataset (e.g., MNIST, GLUE) and commonly used by TF/JAX community. However, it seems like that it is NOT compatible with NVIDIA TensorFlow in NVIDIA JAX container and somehow affects JAX. It triggers random errors at JAX initialization depending on different versions, and make CI unstable. Thus, this commit replaces tfds with "huggingface datasets" to download needed datasets. See "nvbugs 4039266" for more details. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix input sharding Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually. Using DP=4, TP=2 as an example, the device mesh looks like: mesh.device_ids = [[0, 1], [2, 3], [4, 5], [6, 7]] Assume that the process ID is mapped to GPU ID. The process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. The process 0 and process 1 need to share the same micro-batch in the training step, process 0 and process 2, 4, and 6 have different micro-batch. Thus, `shard_array_wrapper` partitions inputs to 4 parts (and setup needed arguments for jax.make_array_from_single_device_arrays). The process 0 and process 1 take the first quarter, process 2 and process 3 take the second quarter, and so on. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor UT for multiprocess example Use Python `multiprocessing` to test the multiprocessing example, if the system has multiple GPU. 1 GPU per process. Because `jax.distributed.initialize` must be called before any other JAX or Flax API, GPU info cannot be queried by calling jax.local_devices() in TestEncoder. Thus, `unittest_query_gpu()` forks another process to query number of GPUs and FP8 capability. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse arg `--num-gpu` JAX doesn't have an API to setup number of GPU used in SPMD mode. The only way is to use `CUDA_VISIBLE_DEVICES` for now. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix ut Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * simplify the mask setting Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * increase batch-size for multigpu example The batch-size 64 is too small to be partitioned for 8xH100. If batch-size is 64, the GEMM shape is 256x8192x8 per GPU. The 8 is too small for FP8 GEMM kernel, and cuBLASLt will throw "Failed to query heuristics". Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix downloading mnist error To download MNIST via `huggingface datasets`, it requires Pillow. Otherwise, it throws `An error occurred while generating the dataset` Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 28 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adjust Module Structure. 1. Collect Flax related modules to a sub-folder, flax. 2. Add a function to unify scale_init for zero-centered-gamma LN. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make changes be compatible to previous versions. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adapt jax/examples to the new module structure. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update jax/docs and Add deprecated warning. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update README Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated_wrapper Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated warning to flax modules which imported via transformer_engine.jax Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix CI errors and update docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removing unnecessary deprecated warning in docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Implementing __iter__ to DeprecatedEnum. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 28 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* refactor JAX examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix params_axes_pspec Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Add model parallel example and refactor Update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align code and readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update verification Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add mask Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * num_gpu is configurable Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * solvepylint issue Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * ignore markdown and txt file from license check Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update README.md Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add flax into requirements.txt Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com>
-