1. 23 Jul, 2024 6 commits
    • Clean up softcapping bwd a bit · 5ca83a9c
      Tri Dao authored
    • 751c762c
      Tri Dao authored
    • Fix ima for split-kv kernel (#1085) · 1c275eb0
      Driss Guessous authored
    • Make FA3 externally importable (#1053) · 3c4053b7
      janEbert authored
      The library name to import is `flash_attn_interface`, which matches the test.
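      A minimal usage sketch of the externally importable FA3 interface follows; the `flash_attn_func` call and its exact return signature are assumptions based on the FA2-style API, not something this commit specifies.
      ```
      # Hypothetical sketch: import FA3 by its interface module name.
      # Assumes an FA2-style flash_attn_func; the return value may include
      # extra outputs (e.g. the softmax LSE) depending on the version.
      import torch
      import flash_attn_interface

      q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
      k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
      v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
      result = flash_attn_interface.flash_attn_func(q, k, v, causal=True)  # may be (out, lse)
      ```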
    • Support AMD ROCm on FlashAttention 2 (#1010) · d8f104e9
      rocking authored

      * Support CK in fmha
      * Add CK submodule
      * Do not return lse if return_softmax == false
      * Use receipt to speed up CK compile time
      * Integrate new version of ck_tile
      * Support dropout for mha_fwd()
      * Add dropout to mha_varlen_fwd()
      * Update CK to develop
      * Extract padding function for dropout randval
      * Extract randval transformation function
      * Sync the code structure and coding style with FA
      * Remove this line; the C++ API will handle this. Sync with test_flash_attn.py
      * Fix compile error
      * Add mha_bwd
      * Generate dropout seed and offset from user generator
      * Update CK
      * Add mha_varlen_bwd
      * Use the same Python as the flash-attn build to generate CK kernels
      * Fix bug in group-mode fwd returning softmax lse
      * Increase the test tolerance
      * Add test_flash_attn_output() and test_flash_attn_varlen_output()
      * Always fill softmax_lse
      * Remove duplicate benchmark script, since we already implement mha_bwd
      * Refine getting value from tuple
      * Use default parameter for stream_config
      * Unblock all platforms
      * Add comment
      * Refine the test code
      * Refine naming
      * Add unpack to namespace
      * Do not hardcode the warp size 64
      * Add more targets
      * Add README
      * Optimize mha_fwd if seqlen_q == 1
      * Support get_wheel_url for ROCm
      * Detect ROCm environment via PyTorch's IS_HIP_EXTENSION (see the sketch below)
      * Update to latest CK
      * Add necessary compile flag
      * Sync the API with upstream FA

      ---------
      Co-authored-by: carlushuang <carlus.huang@amd.com>
      Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
      Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
      Co-authored-by: Yichen Yan <oraluben@outlook.com>
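      As background for the IS_HIP_EXTENSION bullet above, here is a minimal sketch of how a build script can branch on PyTorch's ROCm detection; the backend-selection logic is illustrative only, not flash-attn's actual setup.py.
      ```
      # Sketch of ROCm detection via PyTorch's IS_HIP_EXTENSION flag.
      # The backend selection below is illustrative, not the real build script.
      from torch.utils.cpp_extension import IS_HIP_EXTENSION

      if IS_HIP_EXTENSION:
          backend = "rocm"  # would build the Composable Kernel (CK) path
      else:
          backend = "cuda"  # would build the CUDA/CUTLASS path
      print(f"Building flash-attn for backend: {backend}")
      ```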
    • Add var-seq-len to FA3 fp16 / bf16 fwd (#1072) · dfe1a59e
      Ying Zhang authored

      * fwd var-seq-len
      * fixes
      * benchmark
      * fixes

      ---------
      Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
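      For context on the var-seq-len forward above: variable-length attention packs all sequences of a batch into one tensor and marks the boundaries with cumulative sequence lengths. The sketch below assumes the FA2-style `flash_attn_varlen_func` API; FA3's interface may differ.
      ```
      # Hypothetical varlen sketch: two sequences (lengths 3 and 5) packed into
      # one tensor, boundaries given by cumulative sequence lengths (cu_seqlens).
      import torch
      from flash_attn import flash_attn_varlen_func

      nheads, headdim = 8, 64
      seqlens = [3, 5]
      q = torch.randn(sum(seqlens), nheads, headdim, dtype=torch.bfloat16, device="cuda")
      k, v = torch.randn_like(q), torch.randn_like(q)
      cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
      out = flash_attn_varlen_func(
          q, k, v, cu_seqlens, cu_seqlens,
          max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
      )
      ```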
  2. 22 Jul, 2024 4 commits
    • cb516f85
      Cameron Shinn authored
    • backwards for softcapping (#1033) · 5f1ae4a3
      Phil Wang authored
      * Check in the two ways of approaching backwards for softcapping, both functional
      * Prepare the softcap switch for backwards
      * Temporary
      * Cleanup to the way Tri prefers
      * Calculate dtanh when copying from scores -> dtanh tensor (see the sketch below)
      * No ternary operators allowed for constexpr, so just use some hack found online
      * Fix maybe_dtanh, restore some files
      * Restore another file
      * Move calculate_dtanh to utils and colocate with apply_softcap
      * Cleanup
      * Maybe last cleanup
      * Save for another PR
      * Remove a stray line
      * Fix spacing
      * Fix an issue, and make test_flash_attn.py ready to test softcapping backwards
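      As background for the dtanh calculation referenced above, tanh softcapping and its backward factor can be written as follows; this is reference math only, not the kernel's actual apply_softcap/calculate_dtanh code.
      ```
      # Reference-only math for tanh softcapping; not the CUDA kernel implementation.
      import torch

      def apply_softcap_ref(scores: torch.Tensor, cap: float) -> torch.Tensor:
          # Forward: squash raw attention scores into (-cap, cap).
          return cap * torch.tanh(scores / cap)

      def calculate_dtanh_ref(scores: torch.Tensor, cap: float) -> torch.Tensor:
          # Backward factor: d/ds [cap * tanh(s / cap)] = 1 - tanh(s / cap)**2.
          t = torch.tanh(scores / cap)
          return 1.0 - t * t
      ```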
    • remove lambda (#1056) · ef3e358a
      youkaichao authored
    • catch typo (#1058) · 4df62e14
      Jorge António authored
  3. 15 Jul, 2024 1 commit
  4. 13 Jul, 2024 1 commit
  5. 11 Jul, 2024 8 commits
  6. 10 Jul, 2024 8 commits
  7. 09 Jul, 2024 1 commit
  8. 08 Jul, 2024 2 commits
  9. 03 Jul, 2024 1 commit
  10. 01 Jul, 2024 5 commits
    • Fix typos of comments about shape. (#837) · 9486635c
      66RING authored
    • Fix KeyError handling for non-existing key in state_dict.pop() (#898) · 0d810cfb
      JDKWangGuan authored
      Update handling for KeyError in state_dict.pop() for non-existing keys.
      Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) to prevent KeyError exceptions.
      
      
      The following code reproduces the issue:
      ```
      from transformers import AutoTokenizer, GPT2Model, GPT2Config
      from flash_attn.models.gpt import GPTLMHeadModel, GPTModel
      
      # >>> transformers.__version__
      # '4.38.2'
      
      model_path = 'gpt2'
      output_model_path = 'gpt2_model'
      config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
      model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
      '''
      model fine-tuning here
      '''
      # dump the fine-tuned model
      model.save_pretrained(output_model_path)
      
      # load the fine-tuned model
      config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
      model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
      model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
      
      ```
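      A minimal standalone illustration of why the pop-with-default change avoids the error (plain Python, not the flash_attn source):
      ```
      # dict.pop with a default never raises KeyError for a missing key.
      state_dict = {"h.0.attn.weight": 1.0}          # no "h.0.attn.bias" entry
      # state_dict.pop("h.0.attn.bias")              # would raise KeyError
      value = state_dict.pop("h.0.attn.bias", None)  # returns None instead
      assert value is None
      ```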
    • fix typo (#974) · 6a2a16e9
      cao lei authored
    • Fixing argument checking when using `seqlenq_ngroups_swapped`. (#976) · 5bf20196
      Nicolas Patry authored
      When the user passes `out` as a parameter of the function with inputs that trigger `seqlenq_ngroups_swapped`, the CHECK_SHAPE is incorrect (since the q shape has been modified).
    • Liang authored
  11. 27 Jun, 2024 1 commit
  12. 26 May, 2024 2 commits