Commit 80eea43e authored by Casper Hansen
Browse files

Fix Falcon n_kv_heads parameter

parent d016c513
@@ -38,7 +38,7 @@ class FalconDecoderLayer(nn.Module):
input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
super().__init__()
self.n_heads = n_heads
-self.n_kv_heads = 8
+self.n_kv_heads = 8 if new_decoder_arch else 0
self.hidden_size = hidden_size
self.new_decoder_arch = new_decoder_arch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment