- 12 Jan, 2024 1 commit
-
-
Ming-Xu Huang authored
* Adding Cast custom call Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying cast to the kernel of layernorm_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Applying native cast to the kernel of fp8_dot. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply Cast and native cast to layernorm_geglu_fp8_dot Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp Signed-off-by:
Ming Huang <mingh@nvidia.com> * Modifiied code with the review feedback. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding 2xACC control to FP8 GEMMs. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set precision as an static arg Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 14 Nov, 2023 1 commit
-
-
Ming-Xu Huang authored
* Refactor sharding.py for the further custom_partitioning migration Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * WAR to LN/RMSN_fp8 before migrating to CP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix the wrong order of parameters of bwd of LN/RMSN_fp8. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Following review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Force the hidden dim in Norm ops to no sharding and add warning msg. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * add gelu and dgelu. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Reuse fwd_rule in VJP functions for attentions Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply native FP8 Dtypes to fp8.py Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating cast_and_transpose from xmap to custom_partitioning Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Migrating transpose from xmap to custom_partitioning Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Apply XLA pattern match to perform FP8 GEMM. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate layernorm_fp8 to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify code style Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Extend supported of Transpose with FP8 Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernorm_fp8_dot based on migrated custom calls. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Renaming variables and publish NVTE_FP8_COLLECTION_NAME Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Replace Q/DQ custom calls with native XLA implementations Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * migrate gelu_fp to custom_partitioning. Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Miner fix Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support custom calls with mutli-dims Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Support gerneral dot indices in _fp8_dot_impl Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Implementing layernrom_geglu_fp8_mlp Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove GEMM custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Remove xmap related code Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix typo and add query-function to FP8MetaPackage Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix some bugs of custom calls Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Fix CT's bugs Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update UTs/eaxmaples to adapt to the API changes. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Unify kernel initilization in MLP. Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Modifing with code review's feedback Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Update README and Add deprecating warning to *ShardingType Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> * Canonicalize the dtype Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding assertion for non-supported batch dims. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding doc/examples to _multidim_transpose Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Apply dtype-based rtol/atol to UTs Signed-off-by:
Ming Huang <mingh@nvidia.com> * Deprecate QKV_INTERLEAVED enum Signed-off-by:
Ming Huang <mingh@nvidia.com> * Skip test_distributed_custom_ops.py Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong sharding of bias in SelfAttn Signed-off-by:
Ming Huang <mingh@nvidia.com> * WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding distributed ops unit-tests Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding license to test_distributed_* Signed-off-by:
Ming Huang <mingh@nvidia.com> * Follow review feedback to modify Signed-off-by:
Ming Huang <mingh@nvidia.com> * Use total bytes involved in collective ops as criteria. Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Donglin Yang <dongliny@nvidia.com>
-
- 13 Nov, 2023 1 commit
-
-
zlsh80826 authored
[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen Signed-off-by:
Reese Wang <rewang@nvidia.com> * Increase neg infinity abs values Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove unnecessary code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support variable sequence length after cuDNN 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use unique_ptr instead of shared_ptr Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support flash padding mask after 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the fused attn support lists Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove padding_aware from the caching Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix libtransformer.so issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the pad ratio tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a bug with cuDNN 8.9.5 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Release backend resource after the module level unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean the jax live arrays before running the unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix too-few-public-methods lint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Nov, 2023 1 commit
-
-
Tim Moon authored
* Use FP8 tolerances in JAX FP8 tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Programmatically compute expected floating point error Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Loosen tolerance for MNIST test Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
-
- 12 Oct, 2023 1 commit
-
-
Tim Moon authored
* Debug PyTorch and Paddle tests on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Only run Paddle layer tests with cuDNN fMHA on supported archs Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug PyTorch fMHA tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Reduce JAX FP8 GEMM sizes Avoid split-k kernels on Ada. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Disable JAX fused self-attention test on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Run supported fused attention tests on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Run supported fused attention JAX tests on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Enable Paddle fused attention on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Update reference scale calculation in TensorFlow test Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Restore backend support to reference FP8 attention impl in PyT test Review suggestion from @cyanguwa Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Fix merge conflicts Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug Paddle tests on Ada Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Loosen tolerances for Paddle attention tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Assume causal mask implies equal seqlens in Paddle attention tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add mp example Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update doc-string Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * better FP8 checker Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * update readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace te.* with te.flax* to remove deprecated warning Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse os.environ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/test_multiprocessing_encoder.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove cuda-python Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * adjust readme Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update examples/jax/encoder/README.md Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix cpp lint fix issue of "Could not find a newline character at the end of the file." Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix AssertionError: 1 GPU per process Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace tfds with datasets The Flax application crash if it use TensorFlow Dataset (tfds) in NVIDIA JAX container. The tfds is very useful for downloading well-knwon dataset (e.g., MNIST, GLUE) and commonly used by TF/JAX community. However, it seems like that it is NOT compatible with NVIDIA TensorFlow in NVIDIA JAX container and somehow affects JAX. It triggers random errors at JAX initialization depending on different versions, and make CI unstable. Thus, this commit replaces tfds with "huggingface datasets" to download needed datasets. See "nvbugs 4039266" for more details. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix input sharding Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually. Using DP=4, TP=2 as an example, the device mesh looks like: mesh.device_ids = [[0, 1], [2, 3], [4, 5], [6, 7]] Assume that the process ID is mapped to GPU ID. The process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. The process 0 and process 1 need to share the same micro-batch in the training step, process 0 and process 2, 4, and 6 have different micro-batch. Thus, `shard_array_wrapper` partitions inputs to 4 parts (and setup needed arguments for jax.make_array_from_single_device_arrays). The process 0 and process 1 take the first quarter, process 2 and process 3 take the second quarter, and so on. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor UT for multiprocess example Use Python `multiprocessing` to test the multiprocessing example, if the system has multiple GPU. 1 GPU per process. Because `jax.distributed.initialize` must be called before any other JAX or Flax API, GPU info cannot be queried by calling jax.local_devices() in TestEncoder. Thus, `unittest_query_gpu()` forks another process to query number of GPUs and FP8 capability. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove nouse arg `--num-gpu` JAX doesn't have an API to setup number of GPU used in SPMD mode. The only way is to use `CUDA_VISIBLE_DEVICES` for now. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix typo Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix ut Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * simplify the mask setting Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * increase batch-size for multigpu example The batch-size 64 is too small to be partitioned for 8xH100. If batch-size is 64, the GEMM shape is 256x8192x8 per GPU. The 8 is too small for FP8 GEMM kernel, and cuBLASLt will throw "Failed to query heuristics". Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix downloading mnist error To download MNIST via `huggingface datasets`, it requires Pillow. Otherwise, it throws `An error occurred while generating the dataset` Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 13 Apr, 2023 1 commit
-
-
zlsh80826 authored
* Add zero_center_gamma/functional pass Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma for fp8_ln_mlp Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to modules Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add zero_centered_gamma to TransformerLayer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactored code style for improved readability and consistency Signed-off-by:
Reese Wang <rewang@nvidia.com> * Docs enhancement for zero_centered_gamma Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add escape for line break and remove some bad if conditions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise scale_init docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 09 Mar, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add transformer module , unittests and examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update tests/jax/test_sharding.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/transformer.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint: disable=line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove pylint: disable=too-many-func-args Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Fix the wrong broadcasting dim to dropout masks when enable transpose_bs. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Enable 2xACC for WGRAD and DGRAD by default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename LayerNormMlpBlock as LayerNormMLP Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor to avoid line-too-long Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename amax_history_size to amax_history_len Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * align dropout mask to TE/PyTorch as default Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * enlarge atol for decoder unittests Two decoder unittests can pass in old JAX container(e.g., 23.02) but can't in latest container (devel). 1. The actual(-0.020264) and desired(-0.020386) are very close. 2. The TE kernels are not changed, the diff should come from new codegen behavior of XLA. Thus, it is a common floating-point accumulated error. Enlarge atol to avoid unittest failures. Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Adding Amax History Support 1. hide amax update in custom_vjp 2. replace amax indexing with roll(using circular buffer) Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * move kernel_init to __post_init__ Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor encoder examples Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove envvar regarding 2xACC Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * remove unused import Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Ming-Xu Huang <mingh@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 24 Feb, 2023 1 commit
-
-
Jeng Bai-Cheng authored
* add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * add building workflow for jax modules Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * replace bit_cast with reinterpret_cast Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add nvtx to cmake check list Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor rmsnorm fwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor layernorm_bwd Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * set pytorch as default in setup.py Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * rename extension from *.cc to *.cpp cpplint cannot recognize *.cc file, so rename the extension Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * refactor style, to align TE/PyTorch Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add pybinding, unittest and qa Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix license Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * disable c-extension-no-member and no-name-in-module Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * add dataclass avoid pylint error Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update transformer_engine/__init__.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py fix typo Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * Update tests/jax/test_custom_call_shape.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * fix conflict due to PR62 Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * fix c-extension-no-member and no-name-in-module 1. add transformer_engine_jax into extension-pkg-whitelist 2. convert pylintrc from CRLF to LF format Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> * Update setup.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> * remove pylint:disable and refactor import order Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> --------- Signed-off-by:
Ryan Jeng <rjeng@nvidia.com> Signed-off-by:
Jeng Bai-Cheng <jeng1220@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-