- 05 Sep, 2023 1 commit
-
-
Frédéric Bastien authored
Use the new API when it is available. Signed-off-by:Frederic Bastien <fbastien@nvidia.com>
-
- 30 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* [JAX] Fix incorrect sharding when only enable FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Add WAR to memory misaligned issues of LN BWD. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Reuse sm_arch for avoiding duplicate code. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Support multiple sizes allocation in WorkspaceManager. Signed-off-by:
Ming Huang <mingh@nvidia.com> * [JAX] Use template and ariadic arguments to improve multple sizes allocator. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 25 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fused attention kernel only supports sm80 and sm90 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/csrc/modules.cpp Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * arbitary fused kernel supports sm86/sm89 after 8.9.3 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip sm70 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Forward is_fused_attn_kernel_available to cpp backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cpp is_fused_attn_available API Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 09 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Initially commit for FSDP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding support to fsdp xmap sharding Signed-off-by:
Ming Huang <mingh@nvidia.com> * Specify WeightHParamsCollection of fp8 meta. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support partial FP8 custom calls with FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding amax reduction on the fsdp mesh dim. Signed-off-by:
Ming Huang <mingh@nvidia.com> * clean code Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support FSDP in fMHA. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix missing all-reduce of wgrads along FSDP axis. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Change default value of fsdp_axis_name to for aligning with others Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix RuntimeError: with_sharding_constraint requires a non-empty Signed-off-by:
Ming Huang <mingh@nvidia.com> * Slightly changes (review feedback) Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removed unnecessary comments Signed-off-by:
Ming Huang <mingh@nvidia.com> * Mergeing input_dp_dim into weight_fsdp_dim_map Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update transformer_engine/jax/sharding.py Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 07 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fix flash attention dropout probability with inference Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add output as the fused attention ctx tensor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state as the fused attention ctx tensors Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention supported lengths to the fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor attention primitive to reuse abstract shaped array Signed-off-by:
Reese Wang <rewang@nvidia.com> * Detect backend type to allocate appropriate ctx size Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip dropout correctness instead of return success Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use cudaMemsetAsync and enhance the error handling Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention kernel elts_per_thread update Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant max 512 suffix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Keep only DType and remove NVTEDType from python Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a float32_attention_logits bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Re-calculate workspace size for self attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance bias/dbias shape guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the seed/rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use jax.core.ShapedArray as jax.abstract_arrays is deprecated Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the unittest docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 18 Jul, 2023 1 commit
-
-
zlsh80826 authored
* Fully remove attn_type and set self_attn_mask_type default to 'causal' Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix tests with new arguments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Explicit self_attn_mask_type for examples Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/flax/transformer.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> * Update transformer_engine/jax/flax/transformer.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 07 Jul, 2023 1 commit
-
-
Ming-Xu Huang authored
Signed-off-by:Ming Huang <mingh@nvidia.com>
-
- 20 Jun, 2023 2 commits
-
-
zlsh80826 authored
* Enable fused attention dropout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Cast the uint32 key/counter to int64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update dropout support in fused attention docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise devPtrCuSeqlen* to align the naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support different Jax PRNG impls Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert CastAsync since it is not used Signed-off-by:
Reese Wang <rewang@nvidia.com> * Implement is_training for 16-bit fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attn with dropout sanity unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the comments readability and rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the attention dropout shape to align other frameworks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Make encoder tests deterministic Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the default seed for the jax encoder tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Maintain offset in TE Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the resource safety Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert rng_state type to allow only i64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle the corner case for elts_per_threads calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Populate rng state by kernels Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename rng_state as seed in cpp_extensions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the attention dropout comment Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
zlsh80826 authored
* Add self_attn_mask_type and replace attn_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the keyword style for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace attn_type with attn_mask_type in praxis transformer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix typos Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 07 Jun, 2023 1 commit
-
-
Frédéric Bastien authored
* Use the same default in the function to what the class default. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Assert instead of silently ignoring not supported variation. Small doc correction, amax_compute_algo is partially supported. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Fix line lenght to fix the CI. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Apply suggestions from code review Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> * grammar Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Clarify that it is only TE/JAX that don't support that faeture. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> * Update the test following the change in default value Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Fix ci Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 06 Jun, 2023 1 commit
-
-
Ming-Xu Huang authored
Signed-off-by:Ming Huang <mingh@nvidia.com>
-
- 02 Jun, 2023 1 commit
-
-
Jan Bielak authored
* Ignore IDE files Signed-off-by:
Jan Bielak <jbielak@nvidia.com> * Fix typing errors Signed-off-by:
Jan Bielak <jbielak@nvidia.com> * Ignore devcontainer files Signed-off-by:
Jan Bielak <jbielak@nvidia.com> * Avoid import from private module Signed-off-by:
Jan Bielak <jbielak@nvidia.com> * Apply @timmoon10 's suggestions Signed-off-by:
Jan Bielak <jbielak@nvidia.com> --------- Signed-off-by:
Jan Bielak <jbielak@nvidia.com>
-
- 31 May, 2023 1 commit
-
-
Tim Moon authored
* Refactor Setuptools build system Successfully launches CMake install, but installs CMake extensions in temp dir. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug JAX build Fix pybind11 import. Distinguish between build-time and run-time dependencies. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add helper function to determine dependencies Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add missing license Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug case where system CMake is too old Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add missing license Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Simplify sanity import tests Just importing modules provides richer error messages. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Properly install submodules Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Install helper library for TensorFlow Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Update documentation Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Do not install Ninja by default Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Include Git commit hash in version string Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Override build_ext.build_extensions instead of build_ext.run Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Fix incorrect include path Restore Ninja dependency. Restore overriding build_ext.run func. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Review suggestions from @nouiz Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Disable parallel Ninja jobs in GitHub actions Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Properly install userbuffers lib Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Tweak install docs Review suggestion from @ksivaman Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add examples for specifying framework in docs Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
-
- 23 May, 2023 1 commit
-
-
zlsh80826 authored
* Unfused scale+softmax if bias is present Signed-off-by:
Reese Wang <rewang@nvidia.com> * WAR a causal masking + no_bias bug and add the unittests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the optional args (bias) sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Disable fused attn in JAX by default, enable it with NVTE_USE_FUSED_ATTN Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add thread local for the plan cache Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename dbeta to dbias for the readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add scaled softmax with dropout test cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Updated NVTE_FUSED_ATTN variable name Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 16 May, 2023 2 commits
-
-
Frédéric Bastien authored
Signed-off-by:Frederic Bastien <fbastien@nvidia.com>
-
Ming-Xu Huang authored
* Adding JAX/Praxis modules and dependencies. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding UTs to JAX/Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Remove praxis as a dependency due to not strictly needed Signed-off-by:
Ming Huang <mingh@nvidia.com> * Repalce is_fp8_supported to is_fp8_available Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make Praxis as an optional dependency. 1. Removed 'from . import praxis' in __init__.py. 1.1 Noted, keep 'from . import flax' for deprecated warning. 2. Changed te.flax to te_flax in examples and README.rst. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding a workaround to FP8 training on Praxis. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 12 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
bugfix for softmax lowering Signed-off-by:Ryan Jeng <rjeng@nvidia.com>
-
- 09 May, 2023 3 commits
-
-
Ming-Xu Huang authored
[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196) Fixed missing axes and wrong shape of bias in LayerNormMLP Signed-off-by:Ming Huang <mingh@nvidia.com>
-
Jeng Bai-Cheng authored
* add mp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * better FP8 checker Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace te.* with te.flax* to remove deprecated warning Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse os.environ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/test_multiprocessing_encoder.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove cuda-python Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * adjust readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix cpp lint fix issue of "Could not find a newline character at the end of the file." Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix AssertionError: 1 GPU per process Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace tfds with datasets The Flax application crash if it use TensorFlow Dataset (tfds) in NVIDIA JAX container. The tfds is very useful for downloading well-knwon dataset (e.g., MNIST, GLUE) and commonly used by TF/JAX community. However, it seems like that it is NOT compatible with NVIDIA TensorFlow in NVIDIA JAX container and somehow affects JAX. It triggers random errors at JAX initialization depending on different versions, and make CI unstable. Thus, this commit replaces tfds with "huggingface datasets" to download needed datasets. See "nvbugs 4039266" for more details. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix input sharding Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually. Using DP=4, TP=2 as an example, the device mesh looks like: mesh.device_ids = [[0, 1], [2, 3], [4, 5], [6, 7]] Assume that the process ID is mapped to GPU ID. The process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. The process 0 and process 1 need to share the same micro-batch in the training step, process 0 and process 2, 4, and 6 have different micro-batch. Thus, `shard_array_wrapper` partitions inputs to 4 parts (and setup needed arguments for jax.make_array_from_single_device_arrays). The process 0 and process 1 take the first quarter, process 2 and process 3 take the second quarter, and so on. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor UT for multiprocess example Use Python `multiprocessing` to test the multiprocessing example, if the system has multiple GPU. 1 GPU per process. Because `jax.distributed.initialize` must be called before any other JAX or Flax API, GPU info cannot be queried by calling jax.local_devices() in TestEncoder. Thus, `unittest_query_gpu()` forks another process to query number of GPUs and FP8 capability. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse arg `--num-gpu` JAX doesn't have an API to setup number of GPU used in SPMD mode. The only way is to use `CUDA_VISIBLE_DEVICES` for now. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix ut Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * simplify the mask setting Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * increase batch-size for multigpu example The batch-size 64 is too small to be partitioned for 8xH100. If batch-size is 64, the GEMM shape is 256x8192x8 per GPU. The 8 is too small for FP8 GEMM kernel, and cuBLASLt will throw "Failed to query heuristics". Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix downloading mnist error To download MNIST via `huggingface datasets`, it requires Pillow. Otherwise, it throws `An error occurred while generating the dataset` Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
zlsh80826 authored
* Add fused attention unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_* enums Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_Mask_Type and remove FMHADescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move common functions to utils Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change namespace to fused_attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_fwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused_attn_max_512_bwd_qkvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_bwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant blank line Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a potential bug for cu_seqlen converter Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reformat fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the unfused attention warning message Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the deprecated header Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix flax import Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attention related mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_mask_type and attn_bias_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor jax primitive API * Merge q_cu_seqlen and kv_cu_seqlen * Remove is_causal_masking * Replace seed with rng_state * Add is_training argument Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove dsoftmax from the customcall Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add None guard for bias and dropout_rng Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add version guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add is_fused_attn_kernel_available() to correctly dispatch the attention impl Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the merge conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Adjust the code style Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add the missing blank lines Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the order of FADescriptor members Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the readability of fused_attn_max_512.cu Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize the input dimension unpacking Signed-off-by:
Reese Wang <rewang@nvidia.com> * 16 bits fused attention requires 8.9.1 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update fused attention support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle None type when sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change to the padding ratio Signed-off-by:
Reese Wang <rewang@nvidia.com> * Performance optimization for non-bias cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert the cudnn-frontend PRIVATE keyword which was used for debugging Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert "Update fused attention support matrix" This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a. Signed-off-by:
Reese Wang <rewang@nvidia.com> * Treat b * s as total_seqs to align ragged cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine test_fused_attn.py * Replace reference code with flax.linen * Remove unnecessary comments * Use AttnMaskType Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the cuDNN compile version Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dropout to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Slightly adjust the headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Typo fix: remove redundant either Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidating fused attention requirements Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace cudnn_frontend::throw_if with NVTE_CHECK for the better error line report Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_fp16_bf16_max_seqlen_512 for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove CUDNN_FRONTEND_UNUSED Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add more annotations to the custom calls Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 28 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adjust Module Structure. 1. Collect Flax related modules to a sub-folder, flax. 2. Add a function to unify scale_init for zero-centered-gamma LN. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Make changes be compatible to previous versions. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adapt jax/examples to the new module structure. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update jax/docs and Add deprecated warning. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update README Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated_wrapper Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding deprecated warning to flax modules which imported via transformer_engine.jax Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix CI errors and update docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removing unnecessary deprecated warning in docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Implementing __iter__ to DeprecatedEnum. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 20 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Allow update_collections and update_fp8_metas to return both Dict and FrozenDict. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong shape issue of bias when fused QKV or KV. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Reuse tuplized features for bias creating. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Replace get_args to be more readable. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 19 Apr, 2023 1 commit
-
-
Kirthi Shankar Sivamani authored
* Port initial changes Co-authored-by:
Sangkug Lym <slym@nvidia.com> Co-authored-by:
Vasudevan Rengasamy <vrengasamy@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * readd FA include for PyTorch Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Re-enable sm_70 + cleanup Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * LICENSE, cleanup header Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * 5k -> 173 errors Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * license and fixes in userbuffers-host Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * next round fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * final cpp cleanup Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * pylinting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix from linting Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Turn off default async amax reduction (#148) Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * remove unused code path Signed-off-by:
Sangkug Lym <slym@nvidia.com> * cleanup Macros Signed-off-by:
Sangkug Lym <slym@nvidia.com> * fix conflict resolution bug Signed-off-by:
Sangkug Lym <slym@nvidia.com> * Fix gencode flags in setup (#145) * Fix gencode flags based on cuda version Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * review suggestions Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * revert append_nvcc_threads change Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Change overlap config dict error message Signed-off-by:
Sangkug Lym <slym@nvidia.com> * simplify ub initialization Signed-off-by:
Sangkug Lym <slym@nvidia.com> * lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix sanity imports Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * cpplint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix TensorFlow build Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix TE macros in public header Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * More fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * compiles with and w/o MPI Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fixes for python side annotations for conditional compile Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * link gdrAPI only when MPI found Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix comments for dummy var Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix linking Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Review comments Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * load MPI before TE Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add Py side argument checks Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * remove unused code and catch silent failures Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix cpp tests Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix find_lib path for tests Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Sangkug Lym <slym@nvidia.com> Co-authored-by:
Sangkug Lym <slym@nvidia.com> Co-authored-by:
Vasudevan Rengasamy <vrengasamy@nvidia.com>
-
- 13 Apr, 2023 1 commit
-
-
zlsh80826 authored
* Add zero_center_gamma/functional pass Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma for fp8_ln_mlp Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to modules Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to TransformerLayer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactored code style for improved readability and consistency Signed-off-by:
Reese Wang <rewang@nvidia.com> * Docs enhancement for zero_centered_gamma Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add escape for line break and remove some bad if conditions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise scale_init docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 07 Apr, 2023 1 commit
-
-
Ming-Xu Huang authored
* Rename enable_fp8 to is_fp8_enabled. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding an API to get an instance of DelayedScaling which is set via fp8_autocast. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 29 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Support transpose_bs when decoded=True Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix Bugs, 1. Fix missing dropout_dims in LayerNormMLP. 2. Fix broadcast issues in decoded. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix wrong masks in decoded. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fixed wrong assert condition in TransformerLayer Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix amax is not set as 0 in each step. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Enhance rules conflict checking and docs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * fix code formatting. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 28 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* refactor JAX examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix params_axes_pspec Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Add model parallel example and refactor Update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align code and readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update verification Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add mask Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * num_gpu is configurable Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * solvepylint issue Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * ignore markdown and txt file from license check Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update README.md Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add flax into requirements.txt Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com>
-
- 16 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Adding JAX to README.rst Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Refine README.rst as the suggestion from review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Refine the API doc of extend_logical_axis_rules. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 14 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
* Updated TE/JAX docs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding TE/JAX docs' rst files Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set DType as pybind11::module_local() to avoid generic_type errors. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Updating license and exporting more modules Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adopting autoapi and removing enum_tools. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix typo Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Make jax.rst be style consistent. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fixing doc statements as the suggestion from code review. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update the description of Softmax Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Removed categories in catalog as PyTorch Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 10 Mar, 2023 1 commit
-
-
Ming-Xu Huang authored
Signed-off-by:Ming-Xu Huang <mingh@nvidia.com>
-
- 09 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add transformer module , unittests and examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update tests/jax/test_sharding.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/transformer.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint: disable=line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove pylint: disable=too-many-func-args Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Fix the wrong broadcasting dim to dropout masks when enable transpose_bs. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Enable 2xACC for WGRAD and DGRAD by default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename LayerNormMlpBlock as LayerNormMLP Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor to avoid line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename amax_history_size to amax_history_len Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align dropout mask to TE/PyTorch as default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * enlarge atol for decoder unittests Two decoder unittests can pass in old JAX container(e.g., 23.02) but can't in latest container (devel). 1. The actual(-0.020264) and desired(-0.020386) are very close. 2. The TE kernels are not changed, the diff should come from new codegen behavior of XLA. Thus, it is a common floating-point accumulated error. Enlarge atol to avoid unittest failures. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Adding Amax History Support 1. hide amax update in custom_vjp 2. replace amax indexing with roll(using circular buffer) Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * move kernel_init to __post_init__ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor encoder examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove envvar regarding 2xACC Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove unused import Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 24 Feb, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * fix conflict due to PR62 Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix c-extension-no-member and no-name-in-module 1. add transformer_engine_jax into extension-pkg-whitelist 2. convert pylintrc from CRLF to LF format Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update setup.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint:disable and refactor import order Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-