Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bbf1e618
Unverified
Commit
bbf1e618
authored
Jun 28, 2024
by
Arthur
Committed by
GitHub
Jun 28, 2024
Browse files
Gemma capping is a must for big models (#31698)
* softcapping * soft cap before the mask * style * ... * super nit
parent
cb298978
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
0 deletions
+8
-0
src/transformers/models/gemma2/configuration_gemma2.py
src/transformers/models/gemma2/configuration_gemma2.py
+3
-0
src/transformers/models/gemma2/modeling_gemma2.py
src/transformers/models/gemma2/modeling_gemma2.py
+5
-0
No files found.
src/transformers/models/gemma2/configuration_gemma2.py
View file @
bbf1e618
...
...
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
...
...
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
attention_bias
=
False
,
attention_dropout
=
0.0
,
final_logit_softcapping
=
30.0
,
attn_logit_softcapping
=
50.0
,
query_pre_attn_scalar
=
224
,
sliding_window
=
4096
,
**
kwargs
,
...
...
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
self
.
rope_theta
=
rope_theta
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
attn_logit_softcapping
=
attn_logit_softcapping
super
().
__init__
(
pad_token_id
=
pad_token_id
,
...
...
src/transformers/models/gemma2/modeling_gemma2.py
View file @
bbf1e618
...
...
@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
attn_weights
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
2
,
3
))
*
self
.
scaling
if
self
.
config
.
attn_logit_softcapping
is
not
None
:
attn_weights
=
attn_weights
/
self
.
config
.
attn_logit_softcapping
attn_weights
=
torch
.
tanh
(
attn_weights
)
attn_weights
=
attn_weights
*
self
.
config
.
attn_logit_softcapping
if
attention_mask
is
not
None
:
# no matter the length, we just slice it
causal_mask
=
attention_mask
[:,
:,
:,
:
key_states
.
shape
[
-
2
]]
attn_weights
=
attn_weights
+
causal_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment