Unverified Commit 8e08acad authored by Benjamin Minixhofer, committed by GitHub

Support `num_attention_heads` != `num_key_value_heads` in Flax Llama Implementation (#29557)

* fix tinyllama flax modelling

* rename vars to minimize changes

* move

* formatting

* remove unused var
parent f01e1609
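For context (not part of the commit message): grouped-query models such as TinyLlama use fewer key/value heads than query heads, so their key/value projection weights are narrower than the query projection. A rough sketch of the dimensions involved, with illustrative TinyLlama-style numbers (assumed here, not read from this diff):

```python
# Illustrative, assumed sizes (TinyLlama-style); not values from this commit.
hidden_size = 2048
num_attention_heads = 32
num_key_value_heads = 4
head_dim = hidden_size // num_attention_heads   # 64

q_features = num_attention_heads * head_dim     # 2048
kv_features = num_key_value_heads * head_dim    # 256

# The old Flax code sized every projection at hidden_size (2048), so
# checkpoints whose k/v kernels are (2048, 256) could not be loaded.
print(q_features, kv_features)
```

The diff below sizes each projection from its own head count and repeats the key/value heads before the attention product, so such checkpoints load and run correctly.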
@@ -198,24 +198,32 @@ class FlaxLlamaAttention(nn.Module):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

         dense = partial(
             nn.Dense,
-            self.embed_dim,
             use_bias=config.attention_bias,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         )

-        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
-        self.o_proj = dense()
+        self.q_proj = dense(self.num_heads * self.head_dim)
+        self.k_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.v_proj = dense(self.num_key_value_heads * self.head_dim)
+        self.o_proj = dense(self.embed_dim)
+        if (self.head_dim * self.num_heads) != self.embed_dim:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
+                f" and `num_heads`: {self.num_heads})."
+            )

         self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
         self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)

-    def _split_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+    def _split_heads(self, hidden_states, num_heads):
+        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

     def _merge_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
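A minimal standalone sketch (assumed sizes, not taken from the repository) of what the new projection setup yields: the query kernel stays square at `hidden_size`, while the key/value kernels shrink to `num_key_value_heads * head_dim` output features.

```python
# Minimal sketch with assumed sizes; mirrors the projections set up above.
import jax
import jax.numpy as jnp
import flax.linen as nn

embed_dim, num_heads, num_kv_heads = 2048, 32, 4
head_dim = embed_dim // num_heads

q_proj = nn.Dense(num_heads * head_dim, use_bias=False)
k_proj = nn.Dense(num_kv_heads * head_dim, use_bias=False)

x = jnp.ones((1, 7, embed_dim))
print(q_proj.init(jax.random.PRNGKey(0), x)["params"]["kernel"].shape)  # (2048, 2048)
print(k_proj.init(jax.random.PRNGKey(0), x)["params"]["kernel"].shape)  # (2048, 256)
```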
@@ -266,9 +274,9 @@ class FlaxLlamaAttention(nn.Module):
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)

-        query = self._split_heads(query)
-        key = self._split_heads(key)
-        value = self._split_heads(value)
+        query = self._split_heads(query, self.num_heads)
+        key = self._split_heads(key, self.num_key_value_heads)
+        value = self._split_heads(value, self.num_key_value_heads)

         key, query = self.rotary_emb(key, query, position_ids)
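Sketched separately (assumed batch and sequence sizes), the parameterized `_split_heads` now gives queries and keys/values different head axes:

```python
# Standalone sketch of the reshape above, with assumed sizes.
import jax.numpy as jnp

batch, seq, head_dim = 1, 7, 64
num_heads, num_kv_heads = 32, 4

query = jnp.ones((batch, seq, num_heads * head_dim))
key = jnp.ones((batch, seq, num_kv_heads * head_dim))

def split_heads(hidden_states, n_heads):
    return hidden_states.reshape(hidden_states.shape[:2] + (n_heads, head_dim))

print(split_heads(query, num_heads).shape)    # (1, 7, 32, 64)
print(split_heads(key, num_kv_heads).shape)   # (1, 7, 4, 64)
```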
@@ -298,6 +306,9 @@ class FlaxLlamaAttention(nn.Module):
         if self.has_variable("cache", "cached_key") or init_cache:
             key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

+        key = jnp.repeat(key, self.num_key_value_groups, axis=2)
+        value = jnp.repeat(value, self.num_key_value_groups, axis=2)
+
         # transform boolean mask into float mask
         attention_bias = lax.select(
             attention_mask > 0,
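After the cache step, the key/value heads are tiled back up to the query head count so the existing attention computation can be reused unchanged, analogous to the `repeat_kv` helper in the PyTorch Llama model. A small standalone sketch (assumed sizes) of what `jnp.repeat` along axis 2 does here:

```python
# Standalone sketch with assumed sizes: expand 4 key/value heads to 32.
import jax.numpy as jnp

num_heads, num_kv_heads = 32, 4
num_key_value_groups = num_heads // num_kv_heads   # 8

# (batch, seq, num_kv_heads, head_dim); each head's value is its own index
key = jnp.arange(num_kv_heads, dtype=jnp.float32).reshape(1, 1, num_kv_heads, 1)
expanded = jnp.repeat(key, num_key_value_groups, axis=2)

print(expanded.shape)        # (1, 1, 32, 1)
print(expanded[0, 0, :, 0])  # eight 0s, eight 1s, eight 2s, eight 3s
```

Each consecutive group of `num_key_value_groups` query heads therefore attends to the same key/value head, which is the grouped-query attention behavior.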