- 27 Feb, 2024 1 commit

Ming-Xu Huang authored

Support various implementations of RoPE and fix a coordinate representation bug.

Signed-off-by: Ming Huang <mingh@nvidia.com>

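The entry mentions multiple RoPE implementations; conventions differ mainly in how coordinate pairs are chosen (interleaved pairs vs. split halves), which is exactly where a coordinate-representation bug can hide. A minimal JAX sketch of the interleaved-pair convention, illustrative only and not TE's API:

```python
import jax.numpy as jnp

def apply_rope(x, base=10000.0):
    """Rotate query/key coordinates by position-dependent angles.

    x: [batch, seq, heads, head_dim]. This uses the interleaved-pair
    convention; the split-half convention pairs x[..., :d//2] with
    x[..., d//2:] instead.
    """
    b, s, h, d = x.shape
    angles = jnp.arange(s)[:, None] * base ** (-jnp.arange(0, d, 2) / d)
    cos = jnp.cos(angles)[None, :, None, :]   # [1, s, 1, d//2]
    sin = jnp.sin(angles)[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]       # coordinate pairs
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(b, s, h, d)
```
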
- 22 Feb, 2024 1 commit

Reese Wang authored

* Refine MHA API
* Reuse func from the flax
* DPA draft
* qkv packed draft
* Fix test_layer with fused attn
* Add attn_bias_type and enhance a few code flows
* Move scale_factor from __call__ to init
* Enhance the docs
* Add DPA public API and tests
* Refine docs
* Refine docs
* Fix conflict
* Add qkv separate fused attn
* Apply BSHD_BSHD_BSHD format
* Remove debug log
* Add fused attention layer tests
* Add NVTE_FUSED_ATTN docs
* Fine-grained fused attn settings
* Remove the default value of num_attention_heads and head_dim
* Add teardown for fused attn env
* Unify the Optional notation
* Fix Pre/Post scale bias comments
* Add no_mask tests
* Add checkpoint_name for fused attn
* Fix the fused attn batcher

Signed-off-by: Reese Wang <rewang@nvidia.com>

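NVTE_FUSED_ATTN is the environment flag that gates the cuDNN fused-attention path in TE/JAX (it is the renamed flag introduced in the 23 May, 2023 entry below). A minimal usage sketch:

```python
import os

# Toggle the cuDNN fused-attention backend before TE is first used;
# backend selection reads this flag, so set it before building the model.
# "1" enables the fused path, "0" falls back to the unfused path.
os.environ["NVTE_FUSED_ATTN"] = "1"
```
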
- 02 Feb, 2024 1 commit

Ming-Xu Huang authored

* Adding support for sequence parallelism
* Adding RoPE
* Fix wrong batch_logical_axes
* Renaming FSDP outer env var
* Porting RoPE to Praxis layers
* Porting GeLU + [FP8 Cast]
* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU
* Allowing arbitrary dimension of NVShape for the workspace allocation
* Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan
* Modify with review feedback
* Fix bugs
* Fix typo
* Fixed for lint
* Follow review feedback to modify code
* Fix typo
* Port SP to Praxis
* Fix an issue when enabling both GQA and RoPE
* Update docs

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

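checkpoint_name is JAX's hook for tagging intermediates so a rematerialization policy (useful under nn.scan) can decide what to save. A generic sketch of the mechanism; the "ffn1_out" tag is illustrative, not the tag TE uses:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def ffn(x, w1, w2):
    h = jax.nn.gelu(jnp.dot(x, w1))
    # Tag the intermediate so a remat policy can refer to it by name.
    h = checkpoint_name(h, "ffn1_out")
    return jnp.dot(h, w2)

# Save only the tagged residual; recompute everything else in backward.
ffn_remat = jax.checkpoint(
    ffn, policy=jax.checkpoint_policies.save_only_these_names("ffn1_out"))
```
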
- 01 Feb, 2024 1 commit

zlsh80826 authored

* Fix unfused GQA perf
* Remove WAR for "Check failed: reduction_kind.has_value()"

Signed-off-by: Reese Wang <rewang@nvidia.com>

- 29 Jan, 2024 1 commit

Alp Dener authored

* Removed cudaMalloc/WorkspaceManager in JAX csrc; JAX custom ops now request buffers from XLA for their workspace tensors
* Removed unused GEMM C++ API in TE-JAX
* Fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives
* Fixed import order for linting
* Fixed custom op errors due to incorrect static arg nums in JAX jit
* Shifted cudnnSetStream further down the kernel to avoid an error when executing a dummy kernel call with a nullptr stream
* Fixed linting errors for blank lines

Signed-off-by: Alp Dener <adener@nvidia.com>

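The static-arg-nums fix refers to a standard JAX pitfall: configuration that changes the traced program must be declared static under jit. A small, generic illustration (not TE's code):

```python
from functools import partial
import jax
import jax.numpy as jnp

# A string switch alters control flow, so it must be static; otherwise
# jit tries to trace it as an array argument and fails.
@partial(jax.jit, static_argnums=(1,))
def act(x, kind):
    return jax.nn.gelu(x) if kind == "gelu" else jax.nn.relu(x)

act(jnp.ones(4), "gelu")
```
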
- 16 Jan, 2024 1 commit

zlsh80826 authored

* Support num_gqa_groups argument
* Add GQA support on the JAX bridge code
* Fix the kv stride of the arbitrary backend
* Completely rewrite fused attention tests and add GQA coverage
* Support unfused GQA
* Calculate seqlen before the primitive for better perf
* Add GQA layer tests
* Apply code style checks for te_jax
* Apply code style checks for tests
* Add num_gqa_groups doc
* Refine the qkv_type
* Correct the variable naming
* Handle Max512 CAUSAL
* Add WAR for the latest jax image

Signed-off-by: Reese Wang <rewang@nvidia.com>

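For context on "unfused GQA": grouped-query attention shares each KV head across a group of query heads. A minimal sketch of the unfused semantics; real kernels avoid materializing the repeat:

```python
import jax
import jax.numpy as jnp

def unfused_gqa(q, k, v, num_gqa_groups):
    """q: [b, s, hq, d]; k, v: [b, s, num_gqa_groups, d], hq % groups == 0.

    Broadcasts each KV group across its query heads.
    """
    b, s, hq, d = q.shape
    k = jnp.repeat(k, hq // num_gqa_groups, axis=2)   # -> [b, s, hq, d]
    v = jnp.repeat(v, hq // num_gqa_groups, axis=2)
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, -1), v)
```
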
- 12 Jan, 2024 1 commit

Ming-Xu Huang authored

* Adding Cast custom call
* Applying cast to the kernel of layernorm_fp8_dot
* Applying native cast to the kernel of fp8_dot
* Applying Cast and native cast to layernorm_geglu_fp8_dot
* Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp
* Modified code with the review feedback
* Adding 2xACC control to FP8 GEMMs
* Set precision as a static arg

Signed-off-by: Ming Huang <mingh@nvidia.com>

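Accumulation-precision control of this kind maps naturally onto lax precision hints in JAX. A generic sketch of the idea, not TE's exact 2xACC mechanism:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 256), dtype=jnp.bfloat16)
w = jnp.ones((256, 128), dtype=jnp.bfloat16)

# Precision.HIGHEST requests the most accurate accumulation the backend
# offers, at some throughput cost; DEFAULT lets XLA pick the fast path.
y_fast = jax.lax.dot(x, w, precision=jax.lax.Precision.DEFAULT)
y_acc = jax.lax.dot(x, w, precision=jax.lax.Precision.HIGHEST)
```
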
- 03 Jan, 2024 1 commit

Przemyslaw Tredak authored

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

- 14 Dec, 2023 1 commit

Alp Dener authored

Applied Google-advised fix to register custom op primitives with the device dispatch list.

Signed-off-by: Alp Dener <adener@nvidia.com>

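For orientation, the usual registration pattern for JAX GPU custom calls in this era looked roughly like the sketch below; `my_cuda_ops` is a hypothetical pybind11 module exposing PyCapsule entry points, and the registration normally pairs with lowering-rule registration not shown here:

```python
from jax.lib import xla_client

# Hypothetical module with C++ kernel entry points:
# for name, capsule in my_cuda_ops.registrations().items():
#     xla_client.register_custom_call_target(name, capsule, platform="CUDA")
```
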
- 07 Dec, 2023 1 commit

cyanguwa authored

* Integrate cuDNN frontend v1 into fused attention, plus miscellaneous fixes
* Fix lint
* Fix jax/paddle unit tests
* Fix jax/pytorch lint
* Simplify stride generation
* Fix and/or logic in get_backend
* Fix flag_max512 and test_numerics
* Remove v.contiguous() since get_qkv_layout covers it
* Skip fp8 tests for sm89
* Further fix jax CI
* Fix jax CI
* Revert mask type to comma-separated list
* Fix lint
* Fix last two commits
* Integrate v1/pre-release-5
* Clean up pre-release-5 integration and fix FA2.1 commit
* Force dropout to 0 if not training
* Fix Jax CI
* Test bias/alibi and padding+causal; add alibi to unfused DPA
* Set flag_arb to false when non-determinism is not allowed
* Follow up on previous commit; remove redundant python env var setting
* WIP: minor tweaks for tests
* Prepare for tests
* Fix determinism logic for fused attn
* Add bias to bwd
* Fix gpt_checkpointing/dpa_accuracy problem
* Fix some seg fault issues
* Add failure notes
* Remove use of non-determinism var for backend selection
* Minor fix for lint and CI
* Fix workspace size in bwd and uncomment bias test
* Fix get_alibi and remove check_support
* Update tests status
* Remove workspace_opt from FADescriptor_v1
* Disable arbitrary backend + post scale bias in Jax; waiting on PR 525
* Clean up bhsd order
* Swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in the _qkvpacked/_kvpacked API
* Remove support for padding_causal + cross for max512
* Change alibi bias to float32 for bias_1_4/5 tests
* Further clean up tests
* Fix thd fwd output shape for FlashAttention and add backend info for DPA
* Fix definition of the workspace limit when dbias is present
* Further tweak the DP_WORKSPACE_LIMIT definition
* Disallow alibi+no_mask for sdpa flash and update alibi tests
* Update jax/paddle after PR 525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests
* Disable dbias for non-Hopper archs
* Fix layernorm lint
* Remove unused arg for lint
* Remove build dir in setup.py
* Change selection logic to prefer fused attn on sm90
* Fix distributed jax test
* Fix h and s order in header
* Update to cudnn fe v1 branch
* Remove manual setting of the workopt path due to dbias after the v1 update
* Fix paddle CI
* Add post_scale_bias and alibi to the sdpa flash support matrix
* Fix support matrix in header files
* Move headers back to .cu and change seed/offset to int64
* Update Megatron commit in L1 test and remove all prints in fused attn test
* Fix L1 Megatron test
* Fix fp8 arg in L1 Megatron script
* Print only when the debug flag is on
* Remove checkpoint loading to avoid loading other tests' results

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

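Several bullets above concern ALiBi as an attention bias. A minimal sketch of the standard ALiBi construction, illustrative rather than the code added here:

```python
import jax.numpy as jnp

def alibi_bias(num_heads, seqlen, dtype=jnp.float32):
    """Additive ALiBi bias for pre-softmax logits (power-of-two heads).

    Positive (future) offsets are assumed to be removed by the causal
    mask; slopes follow the standard 2^(-8i/n) schedule.
    """
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    rel = jnp.arange(seqlen)[None, :] - jnp.arange(seqlen)[:, None]  # j - i
    return (slopes[:, None, None] * rel[None, :, :]).astype(dtype)
```
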
- 04 Dec, 2023 1 commit

zlsh80826 authored

Add checkpoint_name.

Signed-off-by: Reese Wang <rewang@nvidia.com>

- 01 Dec, 2023 1 commit

zlsh80826 authored

* Add rng_state output for cross fused attention
* Add rng_state and output for the flash attention backward
* Add bias for the jax cross attn API
* Fix a minor bug
* Add bias in the backward for the arbitrary fused attn backend

Signed-off-by: Reese Wang <rewang@nvidia.com>

- 30 Nov, 2023 2 commits

zlsh80826 authored

Support layernorm sm_margin through environment variables.

Signed-off-by: Reese Wang <rewang@nvidia.com>

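The SM margin reserves streaming multiprocessors so norm kernels leave room for overlapping communication kernels. The variable names in this sketch are an assumption modeled on TE's PyTorch flags; check the TE/JAX docs for the exact spelling this commit introduced:

```python
import os

# Assumed names; reserve 8 SMs for forward and backward layernorm kernels.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "8"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "8"
```
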
Ming-Xu Huang authored

Use relative idx in ScaledUpperTriangMaskedSoftmaxFwdPrimitive.abstract to support batching.

Signed-off-by: Ming Huang <mingh@nvidia.com>

- 20 Nov, 2023 1 commit

zlsh80826 authored

* Remove assertion for NO_MASK
* Fix JAX distributed unit test names

Signed-off-by: Reese Wang <rewang@nvidia.com>

- 14 Nov, 2023 1 commit

Ming-Xu Huang authored

* Refactor sharding.py for the further custom_partitioning migration
* Migrate both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning
* Migrate both FWD and BWD of all kinds of softmax from xmap to custom_partitioning
* Fix the wrong order of parameters to LN/RMSNorm bwd in ln_mlp_fp8
* WAR for LN/RMSNorm_fp8 before migrating to custom_partitioning
* Fix the wrong order of parameters of bwd of LN/RMSNorm_fp8
* Follow review feedback to modify
* Force the hidden dim in Norm ops to no sharding and add a warning msg
* Reuse fwd_rule in VJP functions
* Migrate both FWD and BWD of self-fused-attn from xmap to custom_partitioning
* Migrate both FWD and BWD of cross-fused-attn from xmap to custom_partitioning
* Add gelu and dgelu
* Reuse fwd_rule in VJP functions for attentions
* Apply native FP8 dtypes to fp8.py
* Migrate cast_and_transpose from xmap to custom_partitioning
* Migrate transpose from xmap to custom_partitioning
* Apply XLA pattern match to perform FP8 GEMM
* Migrate layernorm_fp8 to custom_partitioning
* Unify code style
* Extend support of Transpose with FP8
* Implement layernorm_fp8_dot based on migrated custom calls
* Rename variables and publish NVTE_FP8_COLLECTION_NAME
* Replace Q/DQ custom calls with native XLA implementations
* Migrate gelu_fp8 to custom_partitioning
* Minor fix
* Support custom calls with multi-dims
* Support general dot indices in _fp8_dot_impl
* Implement layernorm_geglu_fp8_mlp
* Remove GEMM custom calls
* Remove xmap-related code
* Fix typo and add query function to FP8MetaPackage
* Fix some bugs of custom calls
* Fix CT's bugs
* Update UTs/examples to adapt to the API changes
* Unify kernel initialization in MLP
* Modify with code review feedback
* Update README and add deprecation warning to *ShardingType
* Canonicalize the dtype
* Add assertion for non-supported batch dims
* Add docs/examples to _multidim_transpose
* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules
* Apply dtype-based rtol/atol to UTs
* Deprecate QKV_INTERLEAVED enum
* Skip test_distributed_custom_ops.py
* Fix the wrong sharding of bias in SelfAttn
* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled
* Add distributed ops unit tests
* Add license to test_distributed_*
* Follow review feedback to modify
* Use total bytes involved in collective ops as the criterion

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>

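For readers unfamiliar with the xmap-to-custom_partitioning migration: custom_partitioning lets an op declare how it shards under pjit/jit. A simplified sketch of the pattern for an RMSNorm-like op, with the hidden dim forced to replication as the entry above describes; the callback signatures reflect JAX 0.4.x and the shapes/specs handling is simplified:

```python
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec

@custom_partitioning
def rmsnorm(x):
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)

def _infer_sharding(mesh, arg_infos, result_infos):
    # Keep the operand's sharding but replicate the hidden (last) dim,
    # mirroring the "no sharding on the hidden dim" rule above.
    spec = arg_infos[0].sharding.spec
    return NamedSharding(mesh, PartitionSpec(*spec[:-1], None))

def _partition(mesh, arg_infos, result_infos):
    out_sharding = _infer_sharding(mesh, arg_infos, result_infos)
    def lower_fn(x):  # runs per shard; the norm is local to the last dim
        return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)
    return mesh, lower_fn, out_sharding, (out_sharding,)

rmsnorm.def_partition(infer_sharding_from_operands=_infer_sharding,
                      partition=_partition)
```
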
- 13 Nov, 2023 1 commit

zlsh80826 authored

[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes to JAX bias (#469)

* Move bias to float32
* Enable varlen
* Increase neg infinity abs values
* Enable varlen tests
* Remove unnecessary code
* Fix lint
* Support variable sequence length after cuDNN 8.9.6
* Use unique_ptr instead of shared_ptr
* Add a new mask type: PADDING_CAUSAL_MASK
* Support flash padding mask after 8.9.6
* Enhance the Max512 handling for causal masking and add the related tests
* Update the fused attn support lists
* Remove padding_aware from the caching
* Fix libtransformer.so issue
* Reduce the pad ratio tests
* Fix a bug with cuDNN 8.9.5
* Release backend resources after the module-level unit test
* Clean the jax live arrays before running the unit tests
* Fix too-few-public-methods lint

Signed-off-by: Reese Wang <rewang@nvidia.com>

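A generic sketch of PADDING_CAUSAL_MASK semantics (True = attend), illustrating what the new mask type combines; this mirrors the described behavior, not the kernel's layout:

```python
import jax.numpy as jnp

def padding_causal_mask(seqlens, max_seqlen):
    """seqlens: [b] actual lengths; returns [b, s, s] boolean mask."""
    pos = jnp.arange(max_seqlen)
    valid = pos[None, :] < seqlens[:, None]     # [b, s] padding mask
    causal = pos[None, :] <= pos[:, None]       # [s, s] lower triangle
    return causal[None, :, :] & valid[:, None, :] & valid[:, :, None]
```
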
- 08 Nov, 2023 1 commit

zlsh80826 authored

* Deprecate QKV_INTERLEAVED use in JAX
* Deprecate QKV_INTERLEAVED use in Paddle
* Enhance qkv enum mappings
* Fix LD_LIBRARY_PATH issue
* Arbitrary seqlen kernels only support self attention currently

Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: rewang <rewang@nvidia.com>

- 24 Oct, 2023 1 commit

Tim Moon authored

* Do not include logging macros in installed C headers
* Debug logging macros
* Debug C++ tests. Use Google style for header includes
* Update CUDA driver macros, incorporating changes from #389
* Use core error-checking macros in PyTorch extensions. Hack to get around a macro redefinition warning
* Fix missing arg when getting CUDA driver error string
* Reuse logging header in frameworks

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jan Bielak <jbielak@nvidia.com>

- 20 Oct, 2023 2 commits

Alp Dener authored

Fixed incorrect use of extend_fsdp_sharding_meta() in cross_fused_attn().

Signed-off-by: Alp Dener <adener@nvidia.com>

zlsh80826 authored

Canonicalize the dtype for a better user experience.

Signed-off-by: Reese Wang <rewang@nvidia.com>

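Dtype canonicalization maps a user-supplied dtype onto what JAX will actually use, so layers can accept flexible dtype arguments without surprising the backend:

```python
import jax.numpy as jnp
from jax.dtypes import canonicalize_dtype

# float64 canonicalizes to float32 unless the x64 config flag is enabled.
print(canonicalize_dtype(jnp.float64))  # float32 under the default config
```
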
- 10 Oct, 2023 1 commit

Kirthi Shankar Sivamani authored

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

- 06 Oct, 2023 1 commit

Ming-Xu Huang authored

[JAX] Enhance Dropout in TransformerLayer:
1. Fixed missing setup of the dropout RNG key in TransformerLayer and LayerNormMLP.
2. Allow a separate dropout rate for FC1's output and other hidden states.

* Fix wrong fp8 scale in _update_fp8_metas_impl
* Fix typo

Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

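The RNG-key fix concerns how flax threads dropout randomness through modules. A minimal flax sketch of the mechanism, not TE's layer code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic=False):
        x = nn.Dense(128)(x)
        # Dropout draws from the "dropout" RNG stream, which the caller
        # must supply; the fix above threads this key through TE's layers.
        return nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)

key = jax.random.PRNGKey(0)
model = Block()
params = model.init({"params": key, "dropout": key}, jnp.ones((2, 16)))
y = model.apply(params, jnp.ones((2, 16)),
                rngs={"dropout": jax.random.PRNGKey(1)})
```
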
- 03 Oct, 2023 1 commit

Frédéric Bastien authored

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

- 27 Sep, 2023 1 commit

Kirthi Shankar Sivamani authored

Change deprecation warnings.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

- 23 Sep, 2023 1 commit

Kirthi Shankar Sivamani authored

* Change scaling factor from E8M0 to E8M23
* Fix formula

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

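E8M0 denotes a power-of-two scale (exponent only), while E8M23 is a full FP32 value. A sketch of the difference for an amax-based FP8 scale; this is the generic recipe, not necessarily TE's exact formula:

```python
import math

amax = 3.7          # running max |x| observed for the tensor
fp8_max = 448.0     # E4M3 representable maximum

scale_fp32 = fp8_max / amax                                 # E8M23: full FP32
scale_pow2 = 2.0 ** math.floor(math.log2(fp8_max / amax))   # E8M0: exponent only
print(scale_fp32, scale_pow2)   # the FP32 scale tracks amax more tightly
```
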
- 05 Sep, 2023 1 commit

Frédéric Bastien authored

Use the new API when it is available.

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

- 30 Aug, 2023 1 commit

Ming-Xu Huang authored

* [JAX] Fix incorrect sharding when only FSDP is enabled
* [JAX] Add WAR for memory misalignment issues in LN BWD
* [JAX] Reuse sm_arch to avoid duplicate code
* [JAX] Support multiple-size allocation in WorkspaceManager
* [JAX] Use templates and variadic arguments to improve the multiple-size allocator

Signed-off-by: Ming Huang <mingh@nvidia.com>

- 25 Aug, 2023 1 commit

zlsh80826 authored

* Fused attention kernels only support sm80 and sm90
* Update transformer_engine/jax/csrc/modules.cpp
* Arbitrary fused kernel supports sm86/sm89 after cuDNN 8.9.3
* Skip sm70
* Forward is_fused_attn_kernel_available to the cpp backend
* Remove cpp is_fused_attn_available API

Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

- 09 Aug, 2023 1 commit

Ming-Xu Huang authored

* Initial commit for FSDP
* Add support for fsdp xmap sharding
* Specify WeightHParamsCollection of fp8 meta
* Support partial FP8 custom calls with FSDP
* Add amax reduction on the fsdp mesh dim
* Clean code
* Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1
* Support FSDP in fMHA
* Fix missing all-reduce of wgrads along the FSDP axis
* Change the default value of fsdp_axis_name to align with others
* Fix "RuntimeError: with_sharding_constraint requires a non-empty"
* Slight changes (review feedback)
* Remove unnecessary comments
* Merge input_dp_dim into weight_fsdp_dim_map
* Update transformer_engine/jax/sharding.py

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

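The sharding_constraint bullets refer to JAX's mechanism for pinning an intermediate's layout. A minimal sketch under a one-axis "fsdp" mesh (axis name chosen for illustration):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

@jax.jit
def f(x):
    # Pins the intermediate's layout; calling this without a usable mesh
    # is what raises the "requires a non-empty ..." RuntimeError above.
    return jax.lax.with_sharding_constraint(
        x * 2, NamedSharding(mesh, PartitionSpec("fsdp", None)))

y = f(jnp.ones((8, 128)))
```
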
- 07 Aug, 2023 1 commit

zlsh80826 authored

* Fix flash attention dropout probability with inference
* Add output as a fused attention ctx tensor
* Add rng_state as a fused attention ctx tensor
* Add flash attention supported lengths to the fused attention
* Refactor attention primitive to reuse abstract shaped arrays
* Detect the backend type to allocate the appropriate ctx size
* Skip dropout correctness check instead of returning success
* Use cudaMemsetAsync and enhance the error handling
* Add flash attention kernel elts_per_thread update
* Remove redundant max 512 suffix
* Keep only DType and remove NVTEDType from python
* Fix a float32_attention_logits bug
* Re-calculate workspace size for self attention
* Enhance bias/dbias shape guard
* Enhance the seed/rng_state checker
* Use jax.core.ShapedArray as jax.abstract_arrays is deprecated
* Enhance the unittest docs

Signed-off-by: Reese Wang <rewang@nvidia.com>

- 18 Jul, 2023 1 commit

zlsh80826 authored

* Fully remove attn_type and set the self_attn_mask_type default to 'causal'
* Fix tests with the new arguments
* Use explicit self_attn_mask_type in the examples
* Update transformer_engine/jax/flax/transformer.py

Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

- 07 Jul, 2023 1 commit

Ming-Xu Huang authored

Signed-off-by: Ming Huang <mingh@nvidia.com>

- 20 Jun, 2023 2 commits

zlsh80826 authored

* Enable fused attention dropout
* Cast the uint32 key/counter to int64
* Update dropout support in the fused attention docs
* Revise devPtrCuSeqlen* to align the naming
* Support different Jax PRNG impls
* Revert CastAsync since it is not used
* Implement is_training for 16-bit fused attn
* Add fused attn with dropout sanity unit tests
* Enhance the comment readability and the rng_state checker
* Change the attention dropout shape to align with other frameworks
* Make encoder tests deterministic
* Change the default seed for the jax encoder tests
* Maintain the offset in TE
* Enhance the resource safety
* Revert rng_state type to allow only i64
* Handle the corner case for elts_per_threads calculation
* Populate rng state by kernels
* Rename rng_state to seed in cpp_extensions
* Update the attention dropout comment

Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

zlsh80826 authored

* Add self_attn_mask_type and replace attn_type
* Refine the keyword style for better readability
* Replace attn_type with attn_mask_type in the praxis transformer
* Fix typos

Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

- 07 Jun, 2023 1 commit

Frédéric Bastien authored

* Use the same default in the function as the class default
* Assert instead of silently ignoring unsupported variations. Small doc correction: amax_compute_algo is partially supported
* Fix line length to fix the CI
* Apply suggestions from code review
* Grammar
* Clarify that it is only TE/JAX that doesn't support that feature
* Update transformer_engine/jax/fp8.py
* Update the test following the change in default value
* Fix CI

Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

- 06 Jun, 2023 1 commit

Ming-Xu Huang authored

Signed-off-by: Ming Huang <mingh@nvidia.com>

- 02 Jun, 2023 1 commit

Jan Bielak authored

* Ignore IDE files
* Fix typing errors
* Ignore devcontainer files
* Avoid import from a private module
* Apply @timmoon10's suggestions

Signed-off-by: Jan Bielak <jbielak@nvidia.com>

- 31 May, 2023 1 commit

Tim Moon authored

* Refactor the Setuptools build system. Successfully launches the CMake install, but installs CMake extensions in a temp dir
* Debug the JAX build. Fix the pybind11 import; distinguish between build-time and run-time dependencies
* Add helper function to determine dependencies
* Add missing license
* Debug the case where the system CMake is too old
* Add missing license
* Simplify sanity import tests. Just importing modules provides richer error messages
* Properly install submodules
* Install helper library for TensorFlow
* Update documentation
* Do not install Ninja by default
* Include the Git commit hash in the version string
* Override build_ext.build_extensions instead of build_ext.run
* Fix incorrect include path. Restore the Ninja dependency and the overridden build_ext.run func
* Review suggestions from @nouiz
* Disable parallel Ninja jobs in GitHub actions
* Properly install the userbuffers lib
* Tweak install docs (review suggestion from @ksivaman)
* Add examples for specifying the framework in the docs

Signed-off-by: Tim Moon <tmoon@nvidia.com>

- 23 May, 2023 1 commit

zlsh80826 authored

* Use unfused scale+softmax if bias is present
* WAR a causal masking + no_bias bug and add the unit tests
* Fix the optional args (bias) sharding
* Disable fused attn in JAX by default; enable it with NVTE_USE_FUSED_ATTN
* Add thread local for the plan cache
* Rename dbeta to dbias for readability
* Add scaled softmax with dropout test cases
* Update the NVTE_FUSED_ATTN variable name

Signed-off-by: Reese Wang <rewang@nvidia.com>