Unverified Commit f4ec7a28 authored by Arthur, committed by GitHub

[`Gemma2`] Support FA2 softcapping (#31887)

* Support softcapping

* strictly greater than

* update
parent f67e0f7f
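
For context, logit softcapping bounds the attention scores in (-softcap, softcap) before the softmax by scaling into `tanh` and back out; Gemma2 reads the cap from `config.attn_logit_softcapping`. A minimal standalone sketch of the transform (illustrative, not the model code):

```python
import torch

def soft_cap(attn_scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash scores into (-softcap, softcap): near-identity for small
    # scores, saturating smoothly instead of growing without bound.
    return softcap * torch.tanh(attn_scores / softcap)
```

flash-attn applies the same `tanh` capping inside its fused kernel when a non-zero `softcap` argument is passed, starting with release 2.6.0, which is what this commit wires through.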
@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -382,6 +383,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
             q_len,
             dropout=dropout_rate,
             softmax_scale=self.scaling,
+            softcap=self.config.attn_logit_softcapping,
         )
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -402,6 +404,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
         dropout=0.0,
         softmax_scale=None,
         cache_position=0,
+        softcap=None,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -432,7 +435,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
         use_sliding_windows = (
             _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
         )
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {} flash_kwargs = {"softcap"} if is_flash_attn_greater_or_equal("2.6.0") else {}
if use_sliding_windows:
flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
...
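
A hedged sketch of how the assembled kwargs could reach the fused kernel (the wrapper below is hypothetical; `flash_attn_func` accepts `softcap` only from flash-attn 2.6.0 on, hence the version gate):

```python
from flash_attn import flash_attn_func
from transformers.utils import is_flash_attn_greater_or_equal

def run_flash_attention(query, key, value, dropout, softmax_scale, softcap, sliding_window=None):
    # Older flash-attn releases reject unknown kwargs, so `softcap` is
    # only added when the installed version understands it.
    flash_kwargs = {"softcap": softcap} if is_flash_attn_greater_or_equal("2.6.0") else {}
    if sliding_window is not None:
        flash_kwargs["window_size"] = (sliding_window, sliding_window)
    return flash_attn_func(
        query,  # (batch, seq_len, num_heads, head_dim)
        key,
        value,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=True,
        **flash_kwargs,
    )
```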
@@ -128,6 +128,7 @@ from .import_utils import (
     is_essentia_available,
     is_faiss_available,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
...
@@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
 
 
+def is_flash_attn_greater_or_equal(library_version: str):
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
 def is_torchdistx_available():
     return _torchdistx_available
...
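
The new helper generalizes `is_flash_attn_greater_or_equal_2_10` by taking the threshold as an argument; note the parameter cannot be named `version`, or it would shadow the `packaging.version` module used in the body. A quick usage sketch:

```python
from transformers.utils import is_flash_attn_greater_or_equal

# Returns False when flash_attn is not installed, so the check is safe
# to call unconditionally when assembling optional kwargs.
if is_flash_attn_greater_or_equal("2.6.0"):
    print("flash-attn supports the `softcap` kwarg")
```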