Unverified commit cce475a9, authored by Daniël de Kok, committed by GitHub
Browse files

hotfix: Fix number of KV heads (#2202)

Fix number of KV heads
parent 521d0d99
......
@@ -906,8 +906,8 @@ class FlashCausalLM(Model):
         # Validation is done in the model itself
         if num_kv_heads is None:
             # Order is important here.
-            for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]:
-                num_kv_heads = getattr(config, "num_attention_heads", None)
+            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
+                num_kv_heads = getattr(config, attr, None)
                 if num_kv_heads is not None:
                     break
             if num_kv_heads is None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment