- 08 Aug, 2023 1 commit
-
-
Przemyslaw Tredak authored
Fix for the RMSNorm tests/doc/ONNX export to match the actual implementation Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 07 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fix flash attention dropout probability with inference Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add output as the fused attention ctx tensor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state as the fused attention ctx tensors Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention supported lengths to the fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor attention primitive to reuse abstract shaped array Signed-off-by:
Reese Wang <rewang@nvidia.com> * Detect backend type to allocate appropriate ctx size Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip dropout correctness instead of return success Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use cudaMemsetAsync and enhance the error handling Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention kernel elts_per_thread update Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant max 512 suffix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Keep only DType and remove NVTEDType from python Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a float32_attention_logits bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Re-calculate workspace size for self attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance bias/dbias shape guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the seed/rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use jax.core.ShapedArray as jax.abstract_arrays is deprecated Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the unittest docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 02 Aug, 2023 2 commits
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Tian Zheng authored
Refactor fp8 state Signed-off-by:Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
-
- 01 Aug, 2023 1 commit
-
-
Tian Zheng authored
* Add FP8 support - Add FP8 recipe - Add FP8 path for nn layers - Add MNIST FP8 example Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Update README Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix LayerNormMLP FP8 backward Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix FP8 training in float32 accumulation Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix FP8 checkpointing for non forward execution cases (same as #323) Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Refactors and improvements for better code stype, readability and organization Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Remove unnecassary pylint override Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> --------- Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
-
- 29 Jul, 2023 1 commit
-
-
cyanguwa authored
* add support for multi-query/grouped-query attention Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to flash-attn 1.0.6 and build 2.0.0.post1 manually in CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add keyword name for DPA input Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fused attn tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix skipif for pytest Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Update tests/pytorch/test_fused_attn.py Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix TP and SP case Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * add skipifs for pytest Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove higher limit for flash-attn version Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 28 Jul, 2023 1 commit
-
-
Kirthi Shankar Sivamani authored
Minor bug fix Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 27 Jul, 2023 1 commit
-
-
Przemyslaw Tredak authored
* Exposing RMSNorm in pyTorch extensions Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * First pass at the Python API Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Small fixes Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Added numerics tests and fixed issues Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Lint fixes Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Added RMSNorm to LayerNormMLP Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Added ONNX export and tests for RMSNorm Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Fix python lint Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Fix BERT case Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Added normalization option to the TransformerLayer Added tests Fixed test failures Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> * Fix documentation Co-authored-by:
Przemyslaw Tredak <ptrendx@gmail.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix kwarg bug Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix IMA and invalid type error Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Increase RMSNorm threshold for bf16 case Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix ONNX tests Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Przemek Tredak <ptredak@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 26 Jul, 2023 1 commit
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 20 Jul, 2023 2 commits
-
-
Shijie authored
* add flash attn tests Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * update flash attn Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * fix random seed Signed-off-by:
Shijie Wang <jaywan@nvidia.com> --------- Signed-off-by:
Shijie Wang <jaywan@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 19 Jul, 2023 4 commits
-
-
Tian Zheng authored
* Add Linear layer (FP16) Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> - Add BF16 training example - Add fp8_autocast (only supports non-fp8 for now) Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Remove FP8 stuff Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Simplify Linear layer forward Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Add LayerNorm layer (BF16) Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Add LayerNormLinear layer Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Store weights in BF16 Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Add LayerNormMLP layer Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Add BF16 MNIST example Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Remove in-place cast for compatibility with Paddle AMP mechanism Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * README correction Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Add Paddle op as a backend option Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix code format Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix dtype change between iterations Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Minor fixes Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Move forward function out of base layer Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Use Paddle nvtx bindings Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> --------- Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
-
Kirthi Shankar Sivamani authored
* FA v2.0 support Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * fix typo Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
* FA does not support head_dim > 64 on Ada Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add cc8.7 to no FA list Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 18 Jul, 2023 2 commits
-
-
zlsh80826 authored
* Fully remove attn_type and set self_attn_mask_type default to 'causal' Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix tests with new arguments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Explicit self_attn_mask_type for examples Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/flax/transformer.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> * Update transformer_engine/jax/flax/transformer.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
zlsh80826 <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
* layer number bug fixes Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 15 Jul, 2023 1 commit
-
-
Tim Moon authored
* Disable TorchDynamo optimizations in PyTorch modules Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add test for Torch Dynamo Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add torch.dynamo test to qa Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Skip torch.compile test for <v2.0 Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 14 Jul, 2023 3 commits
-
-
Kirthi Shankar Sivamani authored
Bug fix for checkpointing Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
cyanguwa authored
* Fix bprop for cuDNN 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update cuDNN version requirement to 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI; force LD_LIBRARY Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI; force LD_LIBRARY to /opt Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove debug info for paddle Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change cudnn requirement to 8.9.1 for v1 and 8.9.0 for v2; add batch size 32 for unit test; add LD library path for paddle tests temporarily Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove printf line in fused_attn.cpp Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add batch size 32 for unit test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to 0.9.2 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove temporary LD library path used for testing pre-released cudnn 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
-
Kirthi Shankar Sivamani authored
* Deprecate unused APIs Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * review comments Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Review Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 13 Jul, 2023 4 commits
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Neta Zmora authored
* Fix FP32 LayerNorm ONNX export When running inference use a fwd method that is registered with torchscript. Signed-off-by:
Neta Zmora <nzmora@nvidia.com> * Bug fix Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Neta Zmora <nzmora@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
Remove extra buffers Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
* Better dimension assert for FP8 Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * line Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 07 Jul, 2023 1 commit
-
-
Ming-Xu Huang authored
Signed-off-by:Ming Huang <mingh@nvidia.com>
-
- 06 Jul, 2023 1 commit
-
-
Shijie authored
* add more ops Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * add skipif Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * fix bug Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * minor change Signed-off-by:
Shijie Wang <jaywan@nvidia.com> * minor change on coding style Signed-off-by:
Shijie Wang <jaywan@nvidia.com> --------- Signed-off-by:
Shijie Wang <jaywan@nvidia.com>
-
- 01 Jul, 2023 1 commit
-
-
Tim Moon authored
Signed-off-by:
Tim Moon <tmoon@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 30 Jun, 2023 1 commit
-
-
Tejaswin Parthasarathy authored
fix : TE virtuelenv discovery Signed-off-by:tejaswinp <tejaswinp@nvidia.com>
-
- 26 Jun, 2023 2 commits
-
-
Kirthi Shankar Sivamani authored
* Get default dtype from pytorch Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Review comments Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 23 Jun, 2023 1 commit
-
-
Neta Zmora authored
* Fix ONNX export of layer_norm ONNX has a spec bug: ConstantOfShape supports all dtypes except for BF16. To WAR we use dtype FP32 and then cast to BF16. Will also issue a PR to the ONNX sig committee to change the spec in opset 20. Signed-off-by:
Neta Zmora <nzmora@nvidia.com> * fix lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Neta Zmora <nzmora@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 22 Jun, 2023 2 commits
-
-
Tian Zheng authored
* Add cast_transpose Add gelu, gelu_fp8 Add cast_transpose_bgrad_dgelu Add layernorm_fwd and layernorm_fwd_fp8 Add layernorm_bwd Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> * Fix missing header Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> --------- Signed-off-by:
Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
cyanguwa authored
* add long sequence support and unify three backends for fused attention Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to v0.9.1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cpu_float2half_rn with __float2half_rn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix backend selection and NVTEDType Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ci Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * make cudnn plan caches thread_local Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cuDNN throw with NVTE_CHECK Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix replacement of cuDNN throw with NVTE_CHECK Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout probablity to 0 in inference mode Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change negInfinity to be consistent with m512 fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove float2half conversion for scale_dropout Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back runtime api for sm detection Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add gemm3 to enums FP8Fwd/BwdTensors Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change dropout from no to yes for fmha_v1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state in m512 kernels Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix elts_per_thread calculation in kvpacked fwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove dropout=0.0 restriction for m512 fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state completely from m512 kernels Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 21 Jun, 2023 1 commit
-
-
asfiyab-nvidia authored
Signed-off-by:Asfiya Baig <asfiyab@nvidia.com>
-
- 20 Jun, 2023 3 commits
-
-
Kirthi Shankar Sivamani authored
Signed-off-by:Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
zlsh80826 authored
* Enable fused attention dropout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Cast the uint32 key/counter to int64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update dropout support in fused attention docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise devPtrCuSeqlen* to align the naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support different Jax PRNG impls Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert CastAsync since it is not used Signed-off-by:
Reese Wang <rewang@nvidia.com> * Implement is_training for 16-bit fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attn with dropout sanity unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the comments readability and rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the attention dropout shape to align other frameworks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Make encoder tests deterministic Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the default seed for the jax encoder tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Maintain offset in TE Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the resource safety Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert rng_state type to allow only i64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle the corner case for elts_per_threads calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Populate rng state by kernels Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename rng_state as seed in cpp_extensions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the attention dropout comment Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
zlsh80826 authored
* Add self_attn_mask_type and replace attn_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the keyword style for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace attn_type with attn_mask_type in praxis transformer Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix typos Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 16 Jun, 2023 2 commits
-
-
Neta Zmora authored
* Fix softmax ONNX export * BF16 is validated using "fake i/o": ie. instead of using BF16 as input/output, use FP32 input/output and convert to/from BF16 in the forward method. * Wrap softmax symbolic functions with conversion to/from FP32 to produce the same semantics as TE's softmax (compute is performed at FP32 precision regardless of input/output data type). Signed-off-by:
Neta Zmora <nzmora@nvidia.com> * ONNX export Code refactoring Share function compute_in_fp32 between softmax.py (softmax symbolic functions) and te_onnx_extensions.py (the rest of the symbolic functions). Signed-off-by:
Neta Zmora <nzmora@nvidia.com> * Fix exports Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * lint Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Neta Zmora <nzmora@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
Kirthi Shankar Sivamani authored
* ReLU ONNX export Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add GLU variants Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * linter check Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Export reglu, geglu, swiglu Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Review comments Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-