Unverified commit 4c040aba, authored by Joshua Lochner, committed by GitHub

[mistral] Support passing `head_dim` through config (and do not require `head_dim * num_heads == hidden_size`) (#32050)

* Allow `head_dim` to be set in Mistral config

* Add docstring

* Do not require `head_dim * num_heads == hidden_size`

* [run-slow] mistral
parent c50e0551
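
For context, here is a minimal sketch of what this change enables from the user side, assuming a transformers build that includes this commit. The sizes are arbitrary toy values chosen only to show that `head_dim * num_attention_heads` no longer has to equal `hidden_size`; this snippet is not part of the PR itself.

```python
from transformers import MistralConfig, MistralForCausalLM

# Toy configuration: 8 heads of dim 256 give 2048 != hidden_size (1024),
# which this commit now allows.
config = MistralConfig(
    hidden_size=1024,
    intermediate_size=2048,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
    head_dim=256,
)

model = MistralForCausalLM(config)
attn = model.model.layers[0].self_attn
print(attn.q_proj.weight.shape)  # torch.Size([2048, 1024]) -> (num_heads * head_dim, hidden_size)
print(attn.o_proj.weight.shape)  # torch.Size([1024, 2048]) -> (hidden_size, num_heads * head_dim)
```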
```diff
@@ -53,6 +53,8 @@ class MistralConfig(PretrainedConfig):
             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
@@ -104,6 +106,7 @@ class MistralConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -125,6 +128,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
+        self.head_dim = head_dim or hidden_size // num_attention_heads
         # for backward compatibility
         if num_key_value_heads is None:
```
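
The default above keeps existing configs working: when `head_dim` is omitted it resolves to `hidden_size // num_attention_heads`, while an explicit value is taken as-is. A quick illustrative check (again assuming a transformers build containing this change):

```python
from transformers import MistralConfig

# Omitting head_dim falls back to the old implicit value ...
cfg_default = MistralConfig(hidden_size=4096, num_attention_heads=32)
print(cfg_default.head_dim)  # 128 == 4096 // 32

# ... while an explicit head_dim is used directly, decoupled from hidden_size.
cfg_custom = MistralConfig(hidden_size=4096, num_attention_heads=32, head_dim=96)
print(cfg_custom.head_dim)  # 96
```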
```diff
@@ -185,22 +185,17 @@ class MistralAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
         self.rotary_emb = MistralRotaryEmbedding(
             self.head_dim,
```
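
The only shape that changes in the attention module is the input side of `o_proj`, which now follows `num_heads * head_dim` instead of `hidden_size`. Below is a standalone sketch of the resulting projection shapes in plain PyTorch, mirroring the diff above with assumed toy sizes rather than calling the transformers class itself:

```python
import torch
import torch.nn as nn

# Assumed toy sizes; head_dim is intentionally not hidden_size // num_heads.
hidden_size, num_heads, num_kv_heads, head_dim = 1024, 8, 4, 256

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)  # was hidden_size -> hidden_size

x = torch.randn(2, 16, hidden_size)                  # (batch, seq, hidden)
q = q_proj(x).view(2, 16, num_heads, head_dim)       # per-head queries
k = k_proj(x).view(2, 16, num_kv_heads, head_dim)    # grouped key heads
v = v_proj(x).view(2, 16, num_kv_heads, head_dim)    # grouped value heads

# After attention, heads are concatenated back to num_heads * head_dim
# before o_proj maps the result to hidden_size.
attn_out = torch.zeros(2, 16, num_heads * head_dim)  # placeholder for the attention result
y = o_proj(attn_out)
print(y.shape)  # torch.Size([2, 16, 1024])
```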