Unverified Commit 3961e323 authored by shaltielshmid, committed by GitHub

[WIP] Add support for Mistral-Nemo by supporting head_dim through config (#2254)



* Support passing head_dim through config

* Reading `head_dim` with a fallback is necessary since it's a non-standard key in `MistralConfig` (as defined in transformers); see the sketch below, ahead of the diff.

* Shorter diff.

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 9935720c
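
The hunk below only shows where `self.head_size` is consumed. As a rough illustration of the fallback described in the commit message, the attribute could be set earlier in the constructor along these lines (a minimal sketch under the assumption that the config may or may not carry a `head_dim` key; this is not the verbatim code from the PR):

# Minimal sketch (assumed, not the exact code from this PR): prefer an
# explicit `head_dim` from the model config, and fall back to the usual
# hidden_size // num_attention_heads when the key is absent, since
# `head_dim` is a non-standard key in transformers' MistralConfig.
head_dim = getattr(config, "head_dim", None)
self.head_size = (
    head_dim
    if head_dim is not None
    else config.hidden_size // config.num_attention_heads
)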
@@ -149,15 +149,14 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )
 
-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )