- 16 May, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adding JAX/Praxis modules and dependencies. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding UTs to JAX/Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove praxis as a dependency due to not strictly needed Signed-off-by:
Ming Huang <mingh@nvidia.com> * Repalce is_fp8_supported to is_fp8_available Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make Praxis as an optional dependency. 1. Removed 'from . import praxis' in __init__.py. 1.1 Noted, keep 'from . import flax' for deprecated warning. 2. Changed te.flax to te_flax in examples and README.rst. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding a workaround to FP8 training on Praxis. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 12 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
bugfix for softmax lowering Signed-off-by:Ryan Jeng <rjeng@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add mp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * better FP8 checker Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace te.* with te.flax* to remove deprecated warning Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse os.environ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/test_multiprocessing_encoder.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove cuda-python Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * adjust readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix cpp lint fix issue of "Could not find a newline character at the end of the file." Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix AssertionError: 1 GPU per process Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace tfds with datasets The Flax application crash if it use TensorFlow Dataset (tfds) in NVIDIA JAX container. The tfds is very useful for downloading well-knwon dataset (e.g., MNIST, GLUE) and commonly used by TF/JAX community. However, it seems like that it is NOT compatible with NVIDIA TensorFlow in NVIDIA JAX container and somehow affects JAX. It triggers random errors at JAX initialization depending on different versions, and make CI unstable. Thus, this commit replaces tfds with "huggingface datasets" to download needed datasets. See "nvbugs 4039266" for more details. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix input sharding Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually. Using DP=4, TP=2 as an example, the device mesh looks like: mesh.device_ids = [[0, 1], [2, 3], [4, 5], [6, 7]] Assume that the process ID is mapped to GPU ID. The process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. The process 0 and process 1 need to share the same micro-batch in the training step, process 0 and process 2, 4, and 6 have different micro-batch. Thus, `shard_array_wrapper` partitions inputs to 4 parts (and setup needed arguments for jax.make_array_from_single_device_arrays). The process 0 and process 1 take the first quarter, process 2 and process 3 take the second quarter, and so on. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor UT for multiprocess example Use Python `multiprocessing` to test the multiprocessing example, if the system has multiple GPU. 1 GPU per process. Because `jax.distributed.initialize` must be called before any other JAX or Flax API, GPU info cannot be queried by calling jax.local_devices() in TestEncoder. Thus, `unittest_query_gpu()` forks another process to query number of GPUs and FP8 capability. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse arg `--num-gpu` JAX doesn't have an API to setup number of GPU used in SPMD mode. The only way is to use `CUDA_VISIBLE_DEVICES` for now. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix ut Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * simplify the mask setting Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * increase batch-size for multigpu example The batch-size 64 is too small to be partitioned for 8xH100. If batch-size is 64, the GEMM shape is 256x8192x8 per GPU. The 8 is too small for FP8 GEMM kernel, and cuBLASLt will throw "Failed to query heuristics". Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix downloading mnist error To download MNIST via `huggingface datasets`, it requires Pillow. Otherwise, it throws `An error occurred while generating the dataset` Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 28 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adjust Module Structure. 1. Collect Flax related modules to a sub-folder, flax. 2. Add a function to unify scale_init for zero-centered-gamma LN. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make changes be compatible to previous versions. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adapt jax/examples to the new module structure. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update jax/docs and Add deprecated warning. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update README Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated_wrapper Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated warning to flax modules which imported via transformer_engine.jax Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix CI errors and update docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removing unnecessary deprecated warning in docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Implementing __iter__ to DeprecatedEnum. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 28 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* refactor JAX examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix params_axes_pspec Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Add model parallel example and refactor Update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align code and readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update verification Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add mask Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * num_gpu is configurable Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * solvepylint issue Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * ignore markdown and txt file from license check Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update README.md Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add flax into requirements.txt Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com>
-