Unverified Commit cce475a9 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

hotfix: Fix number of KV heads (#2202)

Fix number of KV heads
parent 521d0d99
...@@ -906,8 +906,8 @@ class FlashCausalLM(Model): ...@@ -906,8 +906,8 @@ class FlashCausalLM(Model):
# Validation is done in the model itself # Validation is done in the model itself
if num_kv_heads is None: if num_kv_heads is None:
# Order is important here. # Order is important here.
for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, "num_attention_heads", None) num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None: if num_kv_heads is not None:
break break
if num_kv_heads is None: if num_kv_heads is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment