"examples/vscode:/vscode.git/clone" did not exist on "9f3f58c90559cc8d56c6d2d62ef9b0f1b93e123e"
Unverified commit e3143952, authored by Arthur, committed by GitHub

Refactor flash attention implementation in transformers (#31446)



* dumb commit

* nit

* update

* something like this

* unpack in modeling utils

* safe import

* oups

* update

* nits

* diff convert gemma

* update

* start propagating

* update other modeling code as well

* update for sliding window models

* nits

* more init cleanups

* styling

* fixup

* noice

* pass fixup

* typo typing_extension -> typing_extensions

* torch.nn.functionnal -> torch.nn.functional

* add to import structure

* unpack

* simplify a bit more for this first version

* nut

* update

* update

* nit

* ease the import of `Unpack`

* remove useless `use_sliding_window`

* no qua please

* protect import?

* style

* [run-slow]

* [run slow] llama,gemma,mistral,mixtral

* remove extra kwargs

* fix llama

* address review comments

* apply diff_model_converter to modeling_gemma.py

* remove cache_position 1

* remove cache_position 2

* some cleaning

* refactor gemma2 as well

* apply review comments

* rename file to modeling_flash_attention_utils.py

* siglip refactor

* remove dead code

* is the hub down?

* still down?

* fix siglip

* fix gemma2

* fatal: Could not read from remote repository.

* fix typo in softcap implem

* flaky

* Failed: Timeout >120.0s

---------
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
parent ad4ef3a2
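Several commit-message entries above ("unpack in modeling utils", "ease the import of `Unpack`") refer to typing `**kwargs` with `typing_extensions.Unpack`. The sketch below illustrates that general pattern only; the `FlashAttnKwargs` class and the `attention_forward` function are hypothetical names for illustration, not the library's actual API.

```python
# Illustrative sketch of the Unpack + TypedDict pattern (PEP 692).
# FlashAttnKwargs and attention_forward are assumed names, not transformers code.
from typing import Optional

from typing_extensions import TypedDict, Unpack


class FlashAttnKwargs(TypedDict, total=False):
    # Optional flash-attention keyword arguments a caller may forward.
    sliding_window: Optional[int]
    softcap: Optional[float]


def attention_forward(hidden_size: int, **kwargs: Unpack[FlashAttnKwargs]) -> None:
    # Static type checkers can now verify the names and types of **kwargs.
    window = kwargs.get("sliding_window")
    print(hidden_size, window)


attention_forward(64, sliding_window=4096)
```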
@@ -812,11 +812,11 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


-def is_flash_attn_greater_or_equal(version):
+def is_flash_attn_greater_or_equal(library_version: str):
     if not _is_package_available("flash_attn"):
         return False

-    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


 def is_torchdistx_available():
...
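The rename in this hunk also fixes a name-shadowing bug: the old parameter `version` hid the `packaging.version` module, so `version.parse(...)` would have been called on a string. Below is a self-contained sketch of the fixed helper under that reading; `_is_package_available` is replaced here by a plain `find_spec` check for the sake of a runnable example.

```python
# Sketch of the fixed helper: `library_version` no longer shadows the
# `packaging.version` module used inside the function body.
import importlib.metadata
import importlib.util

from packaging import version


def is_flash_attn_greater_or_equal(library_version: str) -> bool:
    # Simplified stand-in for _is_package_available("flash_attn").
    if importlib.util.find_spec("flash_attn") is None:
        return False
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


# Example: gate a code path on flash-attn >= 2.1.0.
if is_flash_attn_greater_or_equal("2.1.0"):
    print("flash_attn >= 2.1.0 is available")
```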
@@ -41,6 +41,8 @@ SPECIAL_CASES_TO_ALLOW = {
         "expert_layer_offset",
         "expert_layer_period",
     ],
+    "Qwen2Config": ["use_sliding_window"],
+    "Qwen2MoeConfig": ["use_sliding_window"],
     "Gemma2Config": ["tie_word_embeddings"],
     # used to compute the property `self.chunk_length`
     "EncodecConfig": ["overlap"],
...
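This second hunk adds `use_sliding_window` to the `SPECIAL_CASES_TO_ALLOW` allowlist for the Qwen2 configs, so the repository's config-attribute consistency check does not flag the attribute as unused after sliding-window handling moved into the shared flash-attention utilities. The snippet below is a simplified, assumed version of such a check, not the repository's actual script.

```python
# Simplified sketch: config attributes that never appear in the modeling
# source are reported, unless they are explicitly allowlisted.
SPECIAL_CASES_TO_ALLOW = {
    "Qwen2Config": ["use_sliding_window"],
    "Qwen2MoeConfig": ["use_sliding_window"],
    "Gemma2Config": ["tie_word_embeddings"],
}


def unused_config_attributes(config_class: str, attributes: list[str], modeling_source: str) -> list[str]:
    """Return attributes absent from the modeling source and not allowlisted."""
    allowed = set(SPECIAL_CASES_TO_ALLOW.get(config_class, []))
    return [attr for attr in attributes if attr not in modeling_source and attr not in allowed]


# Example: `use_sliding_window` is missing from the (fake) modeling source,
# but the allowlist keeps it from being reported.
source = "config.sliding_window if layer_idx >= config.max_window_layers else None"
print(unused_config_attributes("Qwen2Config", ["use_sliding_window", "hidden_size"], source))
# -> ['hidden_size']
```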