chenpangpang / transformers · Commits · f4ec7a28
Unverified commit f4ec7a28, authored Jul 11, 2024 by Arthur, committed via GitHub on Jul 11, 2024
[`Gemma2`] Support FA2 softcapping (#31887)
* Support softcapping
* strictly greater than
* update
parent f67e0f7f
Showing 3 changed files with 14 additions and 1 deletion (+14 -1)
src/transformers/models/gemma2/modeling_gemma2.py   +6 -1
src/transformers/utils/__init__.py                  +1 -0
src/transformers/utils/import_utils.py              +7 -0
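For context: Gemma 2 caps its attention logits with a scaled tanh before the softmax, and flash-attn 2.6.0 added a `softcap` argument that applies the same cap inside the fused kernel. This commit wires `config.attn_logit_softcapping` through the Flash Attention 2 path so the capping is no longer silently skipped there. A minimal standalone sketch of the operation itself (plain PyTorch, not the library code):

import torch

def softcap_logits(attn_logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Scaled tanh keeps every score in (-softcap, softcap) while staying
    # roughly linear for small values.
    return softcap * torch.tanh(attn_logits / softcap)

scores = 100.0 * torch.randn(2, 8, 16, 16)   # (batch, heads, q_len, kv_len)
capped = softcap_logits(scores, softcap=50.0)
assert capped.abs().max() < 50.0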
src/transformers/models/gemma2/modeling_gemma2.py
@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -382,6 +383,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
             q_len,
             dropout=dropout_rate,
             softmax_scale=self.scaling,
+            softcap=self.config.attn_logit_softcapping,
         )

         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
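The `attn_logit_softcapping` value forwarded above comes straight from the model config. A quick way to inspect it (the 50.0 default is my recollection of the released Gemma 2 configs, so treat it as an assumption):

from transformers import Gemma2Config

config = Gemma2Config()
# Value the hunk above forwards to the kernel's `softcap` kwarg.
print(config.attn_logit_softcapping)  # 50.0 in the default config (assumed)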
@@ -402,6 +404,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
         dropout=0.0,
         softmax_scale=None,
         cache_position=0,
+        softcap=None,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
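With `softcap` threaded through the helper's signature, the fused kernel ultimately receives it as a keyword argument. A hedged sketch of calling the kernel directly (assumes a CUDA device, fp16/bf16 inputs, and flash-attn >= 2.6.0; the `softcap` keyword of `flash_attn_func` is the one this diff relies on):

import torch
from flash_attn import flash_attn_func

# (batch, seq_len, num_heads, head_dim) layout expected by flash-attn
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True, softcap=50.0)
print(out.shape)  # torch.Size([1, 128, 8, 64])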
@@ -432,7 +435,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
         use_sliding_windows = (
             _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
         )
-        flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
+        flash_kwargs = {"softcap": softcap} if is_flash_attn_greater_or_equal("2.6.0") else {}
+        if use_sliding_windows:
+            flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
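To make the gating above concrete, here is a toy reproduction of which kwargs end up forwarded to the kernel (illustrative names, not the transformers internals):

def build_flash_kwargs(softcap, sliding_window, use_sliding_windows, kernel_supports_softcap):
    # softcap only when the installed flash-attn understands it (>= 2.6.0),
    # sliding-window bounds only when the layer actually uses them.
    flash_kwargs = {"softcap": softcap} if kernel_supports_softcap else {}
    if use_sliding_windows:
        flash_kwargs.update({"window_size": (sliding_window, sliding_window)})
    return flash_kwargs

print(build_flash_kwargs(50.0, 4096, True, True))    # {'softcap': 50.0, 'window_size': (4096, 4096)}
print(build_flash_kwargs(50.0, 4096, False, False))  # {}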
src/transformers/utils/__init__.py
@@ -128,6 +128,7 @@ from .import_utils import (
     is_essentia_available,
     is_faiss_available,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
src/transformers/utils/import_utils.py
@@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


+def is_flash_attn_greater_or_equal(version):
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)
+
+
 def is_torchdistx_available():
     return _torchdistx_available
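One detail worth flagging: the new helper names its parameter `version`, which shadows the `packaging.version` module used in the body, so `version.parse(...)` inside the function resolves to the string argument rather than the module. A sketch of a non-shadowing variant (the rename to `library_version` is my suggestion, not part of this commit; the package probe is inlined to keep the snippet self-contained):

import importlib.metadata

from packaging import version


def is_flash_attn_greater_or_equal(library_version: str) -> bool:
    # Renamed parameter so `version` still refers to the packaging module.
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(library_version)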