"tests/models/plbart/test_modeling_plbart.py" did not exist on "e983da0e7d91c100e6e35efcb8a69c8cd41d6e09"
Unverified commit 6ba63ac3, authored by Matt and committed via GitHub

[InternLM] Add support for InternLM (#26302)

* Add config.bias to LLaMA to allow InternLM models to be ported as LLaMA checkpoints

* Rename bias -> attention_bias and add docstring
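
With this change, an InternLM-style checkpoint (whose attention projections carry bias terms) can be described by the stock Llama configuration. A minimal sketch, using illustrative sizes rather than a real InternLM checkpoint:

    from transformers import LlamaConfig, LlamaForCausalLM

    # Hypothetical small sizes for illustration; a ported checkpoint supplies its own.
    config = LlamaConfig(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        attention_bias=True,  # new flag added by this commit; the default stays False
    )
    model = LlamaForCausalLM(config)  # q/k/v/o projections are created with bias terms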
parent 0ac38750
@@ -87,6 +87,8 @@ class LlamaConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
 
     Example:
@@ -125,6 +127,7 @@ class LlamaConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
+        attention_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -147,6 +150,7 @@ class LlamaConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
+        self.attention_bias = attention_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -280,11 +280,10 @@ class LlamaAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
     def _init_rope(self):
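
A quick, illustrative check (not part of the diff) that the flag reaches the projection layers:

    from transformers import LlamaConfig, LlamaForCausalLM

    # Tiny hypothetical config so the model builds quickly.
    cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=1, num_attention_heads=4)
    print(LlamaForCausalLM(cfg).model.layers[0].self_attn.q_proj.bias)        # None: attention_bias defaults to False

    cfg.attention_bias = True
    print(LlamaForCausalLM(cfg).model.layers[0].self_attn.q_proj.bias.shape)  # torch.Size([64]): bias is now allocated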