- 28 Feb, 2024 1 commit
-
-
cyanguwa authored
* added support for arbitrary bias shapes for fused_attn Signed-off-by:
Alp Dener <adener@nvidia.com> * Fix linting Signed-off-by:
Alp Dener <adener@nvidia.com> * Add b1ss/bhss/11ss bias shapes when not requiring dBias Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add bias_b/h to plan cache Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixed compile errors after PR653 merge Signed-off-by:
Alp Dener <adener@nvidia.com> * updated JAX unittests for new bias shapes Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed mismatched mask type checking Signed-off-by:
Alp Dener <adener@nvidia.com> * corrected skip condition Signed-off-by:
Alp Dener <adener@nvidia.com> * fix selection logic for A100s Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * corrected skip checks for bias shapes Signed-off-by:
Alp Dener <adener@nvidia.com> * resolved test issues but neginf with float16 is still problematic with JAX Signed-off-by:
Alp Dener <adener@nvidia.com> * new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv and h_q == h_kv conditions Signed-off-by:
Alp Dener <adener@nvidia.com> * TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**27 for Bfloat16 and -2**15 for Float16 Signed-off-by:
Alp Dener <adener@nvidia.com> * code style fixes and test parameter ID cleanup Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed incorrect skip condition for backward fused attn test Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com> Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
Alp Dener <adener@nvidia.com>
-
- 22 Feb, 2024 1 commit
-
-
Reese Wang authored
* Refine MHA API Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reuse func from the flax Signed-off-by:
Reese Wang <rewang@nvidia.com> * DPA draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * qkv packed draft Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix test_layer with fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_bias_type and enhance a few code flow Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move scale_factor from __call__ to init Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add DPA public API and tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add qkv separate fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply BSHD_BSHD_BSHD format Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove debug log Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attention layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add NVTE_FUSED_ATTN docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fine-grained fused attn settings Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the default value of num_attetnion_head and head_dim Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add teardown for fused attn env Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the Optional notation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix Pre/Post scale bias comments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add no_mask tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add checkpoint_name for fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the fused attn batcher Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Feb, 2024 1 commit
-
-
cyanguwa authored
* test alibi between fa and fu Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move alibi slopes and bias to global to avoid repeating calculation Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix alibi slopes/bias generation Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix _is_flash_attention_supported to allow alibi type Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable padding mask when alibi is used for fused attn arbi backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add support for custom [n_heads] alibi_slopes in flash, fused, unfused attention Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up last commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove alibi_type=none tests as they are unnecessary Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to 1.0.2 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change bias/dbias shape to allow b,1/1,h/b,h in arbi backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak tests for arbi post_scale_bias [1,h,s,s] or alibi_slopes [n_heads] Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change bias/dbias shape in max512 backend - incomplete Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove max512 changes from last commit and disable max512 (and arbi temporarily) for [b, h, s, s]; pending cuDNN backend support Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up and tweak backend selection logic Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace || with () in docstring Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix bias shape for max512 backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * combine slopes/bias generation to one function get_alibi() and fix alibi tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix PR557 bugs Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * encapsulate global alibi tensors into a dict cache Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce alibi slopes test size Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to cudnn-frontend 1.0.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use dBias shape to define bias_b/bias_h because jax materializes dBias rather than Bias in bwd abstract Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 03 Feb, 2024 1 commit
-
-
cyanguwa authored
* Update cudnn frontend to 1.0.3 to fix cudnn v9 Nans Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * make d_out contiguous for bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove cudnnDestroy to let torch handle it Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 29 Jan, 2024 1 commit
-
-
Alp Dener authored
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors. Signed-off-by:
Alp Dener <adener@nvidia.com> * removed unused GEMM C++ API in TE-JAX Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed import order for linting Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed custom op errors due to incorrect static arg nums in JAX jit Signed-off-by:
Alp Dener <adener@nvidia.com> * shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed linting errors for blank lines Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
-
- 16 Jan, 2024 1 commit
-
-
zlsh80826 authored
* Support num_gqa_groups arguments Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA support on the JAX bridge code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the kv stride of the arbitrary backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Complete rewrite fused attention tests and add GQA coverage Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support unfused GQA Signed-off-by:
Reese Wang <rewang@nvidia.com> * Calculate seqlen before the primitive for the better perf Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add GQA layer tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for te_jax Signed-off-by:
Reese Wang <rewang@nvidia.com> * Apply code style checks for tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add num_gqa_groups doc Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the qkv_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Correct the variable naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle Max512 CAUSAL Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add WAR for the latest jax image Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Jan, 2024 1 commit
-
-
cyanguwa authored
fix FP8 dims Signed-off-by:Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
-
- 03 Jan, 2024 1 commit
-
-
Przemyslaw Tredak authored
Signed-off-by:Przemek Tredak <ptredak@nvidia.com>
-
- 13 Dec, 2023 1 commit
-
-
cyanguwa authored
* fix backend selection for sm80 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix compiling warnings in sdpa flash Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add nvte error messages Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_CHECK_CUDNN_FE for error messaging Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable pylint bare-except Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
-
- 07 Dec, 2023 1 commit
-
-
cyanguwa authored
* Integrate cuDNN frontend v1 to fused attention and miscellaneous fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/paddle for unit tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax/pytorch lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify stride generation Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix and/or logic in get_backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix flag_max512 and test_numerics Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove v.contiguous() since get_qkv_layout covers it Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip fp8 tests for sm89 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert mask type to comma-separated list Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last two commits Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * integrate v1/pre-release-5 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup prerelease5 integration and fix FA2.1 commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout to 0 if not training Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix Jax CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * testing bias/alibi and padding+causal; add alibi to unfused DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * set flag_arb to false when non determinism is not allowed Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * followup on prev commit; remove redundant python env var setting Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: minor tweaks for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prepare for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix determinism logic for fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add bias to bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix gpt_checkpointing/dpa_accuracy problem Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix some seg fault issues Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add failure notes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove use of non-deter var for backend selection Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for lint and CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix workspace size in bwd and uncomment bias test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_alibi and remove check_support Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update tests status Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove workspace_opt from FADescriptor_v1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable arbitrary backend + post scale bias in Jax; waiting on PR 525 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up bhsd order Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * swap bias/rng_state order in aux_ctx_tensor and add bias to aux_ctx_tensor in _qkvpacked/_kvpacked API Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove support for padding_causal + cross for max512 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change alibi bias to float32 for bias_1_4/5 tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further clean up tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd fwd output shape for FlashAttention and add backend info for DPA Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix definition of workspace limit when dbias is present Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further tweak DP_WORKSPACE_LIMIT definition Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disallow alibi+no_mask for sdpa flash and update alibi tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update jax/paddle after PR525 and fix DP_WORKSPACE_LIMIT for dbias Jax tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable dbias for non-hopper archs Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix layernorm lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remode unused arg for lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove build dir in setup.py Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change selection logic to prefer fused attn on sm90 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix distributed jax test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix h and s order in header Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to cudnn fe v1 branch Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove manual setting of workopt path due to dbias after v1 update Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix paddle CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add post_scale_bias and alibi to sdpa flash support matrix Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix support matrix in header files Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move headers back to .cu and change seed/offset to int64 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update Megatron commit in L1 test and remove all prints in fused attn test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix L1 Megatron test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 arg in L1 Megatron script Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * print only when debug flag is on Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove checkpointing loading to avoid loading other tests results Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com>
-
- 13 Nov, 2023 1 commit
-
-
zlsh80826 authored
[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen Signed-off-by:
Reese Wang <rewang@nvidia.com> * Increase neg infinity abs values Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enable varlen tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove unnecessary code Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support variable sequence length after cuDNN 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use unique_ptr instead of shared_ptr Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support flash padding mask after 8.9.6 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the fused attn support lists Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove padding_aware from the caching Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix libtransformer.so issue Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reduce the pad ratio tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a bug with cuDNN 8.9.5 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Release backend resource after the module level unit test Signed-off-by:
Reese Wang <rewang@nvidia.com> * Clean the jax live arrays before running the unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix too-few-public-methods lint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 08 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Deprecate QKV_INTERLEAVED use in JAX Signed-off-by:
Reese Wang <rewang@nvidia.com> * Deprecate QKV_INTERLEAVED use in Paddle Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance qkv enum mappings Signed-off-by:
rewang <rewang@nvidia.com> * Fix LD_LIBRARY_PATH issue Signed-off-by:
rewang <rewang@nvidia.com> * Arbitrary seqlen kernels only support self attention currently Signed-off-by:
rewang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
rewang <rewang@nvidia.com>
-
- 09 Oct, 2023 1 commit
-
-
cyanguwa authored
* add support for h2d/2hd in 8.9.6 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cull unit tests in fused_attn.py and add skipif for layout tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add workopt=1 flag for dpa tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update support table for arbi_seqlen backend Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix rotary position embedding and add unit tests accordingly Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further cut down unit tests for CI efficiency Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove einops dependency Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
-
- 25 Sep, 2023 1 commit
-
-
cyanguwa authored
* add flexible layout support Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add support for flexible qkv layout Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add more changes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes for compiling Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove redudant file Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix options device error Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix typos Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more changes; WIP Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more changes; WIP Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes and tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes and wrong results Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * sb3hd/bs3hd working on top of 3xsbhd/bshd/thd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dQ, dK, dV Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add nvtx Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove qkvso_strides on torch side; cover it in generateQKVStrides Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * all 15 layouts pass Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add workspace optimization Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes and test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * removed most debug info/clean up Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add note to deprecate some qkv layouts Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix code for unit tests in test_fused_attn.py Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further remove debug info Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove a couple more comments Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix numerics tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes for lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix onnx for core attn; not fixed Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove nvtx and add env var for workspace opt Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove testing for env var Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace zeros/zeros_like with empty/empty_like Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix nvtx marker name for _q_k_v API Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm80 when compiling for h100 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add mapping from qkv layout to layout group and qkv format Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up enums mapping and remove trailing spaces Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify workspace opt control logic; only need env var Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 test, and minor modifications for other tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * avoid overwriting model configs in unit test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * random fixes/improvements: get_qkv_format/etc, default values, docstrings, comments Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix minor issues: invalid syntax Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change workspace opt logic back to FORCE_WORKSPACE_OPT Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix FP8 tests and generateStrides function Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_backend logic for max512/arbitrary Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unit tests; need cleanup Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up unit tests for layouts, and fix minor lint issue Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor tweaks for CI testing: onnx string issue and test fused attn first Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove one unsupported layout from max512 and add a check to qkvpacked API Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix te layer test; reduce test time Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert compiler option changes; add back sm80 for even h100 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove some unit tests or make them optional to reduce CI time Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove more unit tests temporarily Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _q_k_v in naming and add NVTE_ERROR for FP8 Aux_CTX_Tensors size checks Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add more deprecation notes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove temp tests from last commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace with te::getenv Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints from last commit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove redundant contiguous() Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove thd->bs3hd user warning to avoid GPU sync Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * adjust fused attn bs in tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporary fix for onnx issue; more fixes in PR 437 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused variables Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Charlene Yang Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 12 Sep, 2023 1 commit
-
-
cyanguwa authored
* add workspace optimization for arbitrary_seqlen fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix whitespace for lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add use_workspace_opt to cudnn plan cache and fix workspace estimate Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * modify workspace opt logic; move zero fill to FP8 API only; other minor fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix try/catch Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix std string error when input is nullptr Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove comments Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Add = for required vs allowed workspace comparison Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com>
-
- 25 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fused attention kernel only supports sm80 and sm90 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/csrc/modules.cpp Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * arbitary fused kernel supports sm86/sm89 after 8.9.3 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip sm70 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Forward is_fused_attn_kernel_available to cpp backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cpp is_fused_attn_available API Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 11 Aug, 2023 1 commit
-
-
cyanguwa authored
* miscellenous fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back pytorch csrc extensions.h Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add unit tests for dpa checkpointing Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove seqlen%32/64 checks for now Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix tests for core attn bias Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add tests for changes regarding rng_state in aux_ctx_tensor Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reuse rng tracker from numerics in fused attn; skip checkpointing if FAv2 in numerics Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * uncomment comments used for testing Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix pre/post scale bias Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/attention.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * remove skipifs for FAv2 check after PR366 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove checkpointing tests for transformer layer; dpa tests still provide coverage Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * adjust random number range for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Add upper bound to FA version Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Check backend only when using FusedAttention Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * remove imports/variables related to FAv2 checks Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * further fix random number ranges for tests Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix variable referenced before assignment error Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 07 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fix flash attention dropout probability with inference Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add output as the fused attention ctx tensor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state as the fused attention ctx tensors Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention supported lengths to the fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor attention primitive to reuse abstract shaped array Signed-off-by:
Reese Wang <rewang@nvidia.com> * Detect backend type to allocate appropriate ctx size Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip dropout correctness instead of return success Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use cudaMemsetAsync and enhance the error handling Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention kernel elts_per_thread update Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant max 512 suffix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Keep only DType and remove NVTEDType from python Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a float32_attention_logits bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Re-calculate workspace size for self attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance bias/dbias shape guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the seed/rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use jax.core.ShapedArray as jax.abstract_arrays is deprecated Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the unittest docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 14 Jul, 2023 1 commit
-
-
cyanguwa authored
* Fix bprop for cuDNN 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update cuDNN version requirement to 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI; force LD_LIBRARY Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug paddle CI; force LD_LIBRARY to /opt Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove debug info for paddle Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change cudnn requirement to 8.9.1 for v1 and 8.9.0 for v2; add batch size 32 for unit test; add LD library path for paddle tests temporarily Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove printf line in fused_attn.cpp Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add batch size 32 for unit test Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to 0.9.2 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove temporary LD library path used for testing pre-released cudnn 8.9.3 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
-
- 22 Jun, 2023 1 commit
-
-
cyanguwa authored
* add long sequence support and unify three backends for fused attention Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update cudnn-frontend to v0.9.1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cpu_float2half_rn with __float2half_rn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix backend selection and NVTEDType Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ci Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * make cudnn plan caches thread_local Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CI Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace cuDNN throw with NVTE_CHECK Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix replacement of cuDNN throw with NVTE_CHECK Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force dropout probablity to 0 in inference mode Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change negInfinity to be consistent with m512 fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove float2half conversion for scale_dropout Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back runtime api for sm detection Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add gemm3 to enums FP8Fwd/BwdTensors Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change dropout from no to yes for fmha_v1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state in m512 kernels Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix elts_per_thread calculation in kvpacked fwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove dropout=0.0 restriction for m512 fused attn Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove output_rng_state completely from m512 kernels Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 20 Jun, 2023 1 commit
-
-
zlsh80826 authored
* Enable fused attention dropout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Cast the uint32 key/counter to int64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update dropout support in fused attention docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise devPtrCuSeqlen* to align the naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support different Jax PRNG impls Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert CastAsync since it is not used Signed-off-by:
Reese Wang <rewang@nvidia.com> * Implement is_training for 16-bit fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attn with dropout sanity unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the comments readability and rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the attention dropout shape to align other frameworks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Make encoder tests deterministic Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the default seed for the jax encoder tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Maintain offset in TE Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the resource safety Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert rng_state type to allow only i64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle the corner case for elts_per_threads calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Populate rng state by kernels Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename rng_state as seed in cpp_extensions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the attention dropout comment Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 06 Jun, 2023 1 commit
-
-
cyanguwa authored
* fix headers for doxygen Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix description f16 and use half precision instead Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 23 May, 2023 1 commit
-
-
zlsh80826 authored
* Unfused scale+softmax if bias is present Signed-off-by:
Reese Wang <rewang@nvidia.com> * WAR a causal masking + no_bias bug and add the unittests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the optional args (bias) sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Disable fused attn in JAX by default, enable it with NVTE_USE_FUSED_ATTN Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add thread local for the plan cache Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename dbeta to dbias for the readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add scaled softmax with dropout test cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Updated NVTE_FUSED_ATTN variable name Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
zlsh80826 authored
* Add fused attention unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_* enums Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_Mask_Type and remove FMHADescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move common functions to utils Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change namespace to fused_attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_fwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused_attn_max_512_bwd_qkvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_bwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant blank line Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a potential bug for cu_seqlen converter Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reformat fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the unfused attention warning message Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the deprecated header Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix flax import Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attention related mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_mask_type and attn_bias_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor jax primitive API * Merge q_cu_seqlen and kv_cu_seqlen * Remove is_causal_masking * Replace seed with rng_state * Add is_training argument Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove dsoftmax from the customcall Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add None guard for bias and dropout_rng Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add version guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add is_fused_attn_kernel_available() to correctly dispatch the attention impl Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the merge conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Adjust the code style Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add the missing blank lines Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the order of FADescriptor members Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the readability of fused_attn_max_512.cu Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize the input dimension unpacking Signed-off-by:
Reese Wang <rewang@nvidia.com> * 16 bits fused attention requires 8.9.1 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update fused attention support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle None type when sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change to the padding ratio Signed-off-by:
Reese Wang <rewang@nvidia.com> * Performance optimization for non-bias cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert the cudnn-frontend PRIVATE keyword which was used for debugging Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert "Update fused attention support matrix" This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a. Signed-off-by:
Reese Wang <rewang@nvidia.com> * Treat b * s as total_seqs to align ragged cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine test_fused_attn.py * Replace reference code with flax.linen * Remove unnecessary comments * Use AttnMaskType Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the cuDNN compile version Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dropout to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Slightly adjust the headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Typo fix: remove redundant either Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidating fused attention requirements Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace cudnn_frontend::throw_if with NVTE_CHECK for the better error line report Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_fp16_bf16_max_seqlen_512 for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove CUDNN_FRONTEND_UNUSED Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add more annotations to the custom calls Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 02 May, 2023 1 commit
-
-
cyanguwa authored
* move dbias from input list to output list for bwd Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * split asserts into three for bias checks Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update transformer_engine/pytorch/cpp_extensions.py Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> * fix asserts for bias checks Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * another fix for asserts for bias checks Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by:
cyanguwa <8636796+cyanguwa@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 22 Apr, 2023 1 commit
-
-
cyanguwa authored
remove used function ternary_pw_op_create Signed-off-by:
Charlene Yang <charleney@nvidia.com> Co-authored-by:
Charlene Yang <charleney@nvidia.com>
-
- 21 Apr, 2023 1 commit
-
-
cyanguwa authored
* Add FP8 fused attention to TE for PyTorch Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add c api docs for fused attention Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add exception for unsupported precision/sequence length combinations Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix installation requirement for non fused attn use cases Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix docs for fused-attn Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix Signed-off-by:
Charlene Yang <charleney@nvidia.com> * minor fixes based on PR comments Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix description for kvpacked fwd Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix description of Bias in C api Signed-off-by:
Charlene Yang <charleney@nvidia.com> * minor fixes for cudnn requirement and description for QKV tensors Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix QKV layout description and support matrix for C api Signed-off-by:
Charlene Yang <charleney@nvidia.com> * add asserts to cpp_extensions for qkv layout/bias type/attn mask type Signed-off-by:
Charlene Yang <charleney@nvidia.com> * fix typo precision Signed-off-by:
Charlene Yang <charleney@nvidia.com> --------- Signed-off-by:
Charlene Yang <charleney@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Charlene Yang <charleney@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-