- 08 Nov, 2023 1 commit
-
-
zlsh80826 authored
* Deprecate QKV_INTERLEAVED use in JAX Signed-off-by:
Reese Wang <rewang@nvidia.com> * Deprecate QKV_INTERLEAVED use in Paddle Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance qkv enum mappings Signed-off-by:
rewang <rewang@nvidia.com> * Fix LD_LIBRARY_PATH issue Signed-off-by:
rewang <rewang@nvidia.com> * Arbitrary seqlen kernels only support self attention currently Signed-off-by:
rewang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
rewang <rewang@nvidia.com>
-
- 20 Oct, 2023 1 commit
-
-
Alp Dener authored
fixed incorrect of extend_fsdp_sharding_meta() in cross_fused_attn() Signed-off-by:Alp Dener <adener@nvidia.com>
-
- 25 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fused attention kernel only supports sm80 and sm90 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update transformer_engine/jax/csrc/modules.cpp Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * arbitary fused kernel supports sm86/sm89 after 8.9.3 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip sm70 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Forward is_fused_attn_kernel_available to cpp backend Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove cpp is_fused_attn_available API Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 09 Aug, 2023 1 commit
-
-
Ming-Xu Huang authored
* Initially commit for FSDP Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding support to fsdp xmap sharding Signed-off-by:
Ming Huang <mingh@nvidia.com> * Specify WeightHParamsCollection of fp8 meta. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support partial FP8 custom calls with FSDP. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Adding amax reduction on the fsdp mesh dim. Signed-off-by:
Ming Huang <mingh@nvidia.com> * clean code Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1 Signed-off-by:
Ming Huang <mingh@nvidia.com> * Support FSDP in fMHA. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix missing all-reduce of wgrads along FSDP axis. Signed-off-by:
Ming Huang <mingh@nvidia.com> * Change default value of fsdp_axis_name to for aligning with others Signed-off-by:
Ming Huang <mingh@nvidia.com> * Fix RuntimeError: with_sharding_constraint requires a non-empty Signed-off-by:
Ming Huang <mingh@nvidia.com> * Slightly changes (review feedback) Signed-off-by:
Ming Huang <mingh@nvidia.com> * Removed unnecessary comments Signed-off-by:
Ming Huang <mingh@nvidia.com> * Mergeing input_dp_dim into weight_fsdp_dim_map Signed-off-by:
Ming Huang <mingh@nvidia.com> * Update transformer_engine/jax/sharding.py Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Signed-off-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
-
- 07 Aug, 2023 1 commit
-
-
zlsh80826 authored
* Fix flash attention dropout probability with inference Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add output as the fused attention ctx tensor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add rng_state as the fused attention ctx tensors Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention supported lengths to the fused attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor attention primitive to reuse abstract shaped array Signed-off-by:
Reese Wang <rewang@nvidia.com> * Detect backend type to allocate appropriate ctx size Signed-off-by:
Reese Wang <rewang@nvidia.com> * Skip dropout correctness instead of return success Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use cudaMemsetAsync and enhance the error handling Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add flash attention kernel elts_per_thread update Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant max 512 suffix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Keep only DType and remove NVTEDType from python Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a float32_attention_logits bugs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Re-calculate workspace size for self attention Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance bias/dbias shape guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the seed/rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use jax.core.ShapedArray as jax.abstract_arrays is deprecated Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the unittest docs Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-
- 20 Jun, 2023 1 commit
-
-
zlsh80826 authored
* Enable fused attention dropout Signed-off-by:
Reese Wang <rewang@nvidia.com> * Cast the uint32 key/counter to int64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update dropout support in fused attention docs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revise devPtrCuSeqlen* to align the naming Signed-off-by:
Reese Wang <rewang@nvidia.com> * Support different Jax PRNG impls Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert CastAsync since it is not used Signed-off-by:
Reese Wang <rewang@nvidia.com> * Implement is_training for 16-bit fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused attn with dropout sanity unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the comments readability and rng_state checker Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the attention dropout shape to align other frameworks Signed-off-by:
Reese Wang <rewang@nvidia.com> * Make encoder tests deterministic Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the default seed for the jax encoder tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Maintain offset in TE Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the resource safety Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert rng_state type to allow only i64 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle the corner case for elts_per_threads calculation Signed-off-by:
Reese Wang <rewang@nvidia.com> * Populate rng state by kernels Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename rng_state as seed in cpp_extensions Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update the attention dropout comment Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
-
- 09 May, 2023 1 commit
-
-
zlsh80826 authored
* Add fused attention unit tests Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_* enums Signed-off-by:
Reese Wang <rewang@nvidia.com> * Use NVTE_Mask_Type and remove FMHADescriptor Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move common functions to utils Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change namespace to fused_attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_fwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add fused_attn_max_512_bwd_qkvpacked Signed-off-by:
Reese Wang <rewang@nvidia.com> * Move fused_attn_max_512_bwd_qkvpacked under the general APIs Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove redundant blank line Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix a potential bug for cu_seqlen converter Signed-off-by:
Reese Wang <rewang@nvidia.com> * Reformat fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine the unfused attention warning message Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_max_512 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove the deprecated header Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix flax import Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused attn Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attention related mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add attn_mask_type and attn_bias_type Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refactor jax primitive API * Merge q_cu_seqlen and kv_cu_seqlen * Remove is_causal_masking * Replace seed with rng_state * Add is_training argument Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove dsoftmax from the customcall Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add None guard for bias and dropout_rng Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add version guard Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add is_fused_attn_kernel_available() to correctly dispatch the attention impl Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix the merge conflict Signed-off-by:
Reese Wang <rewang@nvidia.com> * Adjust the code style Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add the missing blank lines Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change the order of FADescriptor members Signed-off-by:
Reese Wang <rewang@nvidia.com> * Enhance the readability of fused_attn_max_512.cu Signed-off-by:
Reese Wang <rewang@nvidia.com> * Generalize the input dimension unpacking Signed-off-by:
Reese Wang <rewang@nvidia.com> * 16 bits fused attention requires 8.9.1 Signed-off-by:
Reese Wang <rewang@nvidia.com> * Update fused attention support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Handle None type when sharding Signed-off-by:
Reese Wang <rewang@nvidia.com> * Change to the padding ratio Signed-off-by:
Reese Wang <rewang@nvidia.com> * Performance optimization for non-bias cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert the cudnn-frontend PRIVATE keyword which was used for debugging Signed-off-by:
Reese Wang <rewang@nvidia.com> * Revert "Update fused attention support matrix" This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a. Signed-off-by:
Reese Wang <rewang@nvidia.com> * Treat b * s as total_seqs to align ragged cases Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Refine test_fused_attn.py * Replace reference code with flax.linen * Remove unnecessary comments * Use AttnMaskType Signed-off-by:
Reese Wang <rewang@nvidia.com> * Unify the cuDNN compile version Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add dropout to the support matrix Signed-off-by:
Reese Wang <rewang@nvidia.com> * Slightly adjust the headers Signed-off-by:
Reese Wang <rewang@nvidia.com> * Typo fix: remove redundant either Signed-off-by:
Reese Wang <rewang@nvidia.com> * Consolidating fused attention requirements Signed-off-by:
Reese Wang <rewang@nvidia.com> * Replace cudnn_frontend::throw_if with NVTE_CHECK for the better error line report Signed-off-by:
Reese Wang <rewang@nvidia.com> * Rename to fused_attn_fp16_bf16_max_seqlen_512 for the better readability Signed-off-by:
Reese Wang <rewang@nvidia.com> * Remove CUDNN_FRONTEND_UNUSED Signed-off-by:
Reese Wang <rewang@nvidia.com> * Add more annotations to the custom calls Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
-