Unverified commit bbf1e618 authored by Arthur, committed by GitHub

Gemma capping is a must for big models (#31698)

* softcapping

* soft cap before the mask

* style

* ...

* super nit
parent cb298978
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping
         super().__init__(
             pad_token_id=pad_token_id,
...
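For reference, a minimal sketch of how the new field surfaces on the config object. This assumes `Gemma2Config` is importable from the top-level `transformers` package, as other model configs are; per the `is not None` guard in the attention diff below, passing `None` disables the cap entirely.

```python
from transformers import Gemma2Config

# Defaults added by this commit: attention logits capped at ±50.0,
# final logits at ±30.0 (existing field).
config = Gemma2Config()
print(config.attn_logit_softcapping)   # 50.0
print(config.final_logit_softcapping)  # 30.0

# Passing None skips the softcapping branch in Gemma2Attention.
uncapped = Gemma2Config(attn_logit_softcapping=None)
print(uncapped.attn_logit_softcapping)  # None
```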
@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
...
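The added block applies tanh softcapping before the causal mask is added: dividing by the cap, applying `tanh`, and rescaling bounds every attention logit to the open interval (-cap, cap). A self-contained sketch of the same transformation, using made-up values rather than the model's actual logits:

```python
import torch


def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    """Tanh softcapping: smoothly squashes values into (-cap, cap)."""
    return cap * torch.tanh(scores / cap)


# Large logits are pulled back toward ±cap, while small ones are nearly
# unchanged (tanh(x) ≈ x near zero), which keeps attention scores from
# blowing up in big models.
raw = torch.tensor([-500.0, -10.0, 0.0, 10.0, 500.0])
print(softcap(raw, cap=50.0))  # ≈ [-50.0, -9.87, 0.0, 9.87, 50.0]
```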