Unverified Commit cd9f9d63 authored by Jong-hun Shin, committed by GitHub

[gpt-neox] Add attention_bias config to support models trained without attention biases (#28126)

* add attention_bias hparam for a model trained without attention biases

* fix argument documentation error
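
For context, a minimal usage sketch of the new flag once this change is in place. The tiny model dimensions below are arbitrary and only keep the example cheap to run; `attention_bias` defaults to `True`, which preserves the previous behaviour for existing checkpoints.

```python
from transformers import GPTNeoXConfig

# Build a small GPT-NeoX config for a model trained without attention biases.
# attention_bias=True (the default) matches the old behaviour; False disables
# the bias terms on the QKV and output projections introduced by this change.
config = GPTNeoXConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    attention_bias=False,
)
print(config.attention_bias)  # False
```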
parent def581ef
@@ -86,6 +86,8 @@ class GPTNeoXConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
     Example:
@@ -126,6 +128,7 @@ class GPTNeoXConfig(PretrainedConfig):
         tie_word_embeddings=False,
         use_parallel_residual=True,
         rope_scaling=None,
+        attention_bias=True,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -147,6 +150,7 @@ class GPTNeoXConfig(PretrainedConfig):
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
         self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
         self._rope_scaling_validation()
         if self.hidden_size % self.num_attention_heads != 0:
...
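
Since `attention_bias` is stored like any other config attribute, it should round-trip through `config.json`; a hedged sketch (the directory name below is arbitrary):

```python
from transformers import GPTNeoXConfig

config = GPTNeoXConfig(attention_bias=False)
config.save_pretrained("gpt-neox-no-bias")  # non-default values such as attention_bias land in config.json
reloaded = GPTNeoXConfig.from_pretrained("gpt-neox-no-bias")
assert reloaded.attention_bias is False
```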
@@ -117,8 +117,8 @@ class GPTNeoXAttention(nn.Module):
         self._init_rope()
         self.norm_factor = self.head_size**-0.5
-        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.is_causal = True
...
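
With the projections wired to `config.attention_bias`, building a model from a bias-free config should leave the corresponding `nn.Linear` bias parameters unset. A quick sanity check, assuming the `model.layers[i].attention` attribute path used in `modeling_gpt_neox.py`:

```python
from transformers import GPTNeoXConfig, GPTNeoXModel

config = GPTNeoXConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    attention_bias=False,
)
model = GPTNeoXModel(config)

attn = model.layers[0].attention
assert attn.query_key_value.bias is None  # nn.Linear(..., bias=False) registers no bias parameter
assert attn.dense.bias is None
```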