Unverified Commit 4f55f158 authored by Nicolas Patry, committed by GitHub

Fixing baichuan override. (#2158)

parent d0225b10
...@@ -117,6 +117,11 @@ class FlashLlamaAttention(torch.nn.Module):
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        # Setting defaults for baichuan custom config which doesn't apply them.
        config.rope_theta = getattr(config, "rope_theta", 10000)
        config.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
...
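For context, the `getattr(obj, name, default)` pattern in the added lines only falls back to the default when the attribute is missing, so configs that already define `rope_theta` or `num_key_value_heads` are left untouched. A minimal sketch of the same idea, using a `SimpleNamespace` as a stand-in for the model config (the real code operates on the transformers config object loaded for Baichuan):

```python
from types import SimpleNamespace

# Baichuan-style custom config that omits rope_theta and num_key_value_heads.
config = SimpleNamespace(num_attention_heads=32)

# getattr returns the existing attribute if present, else the fallback value.
config.rope_theta = getattr(config, "rope_theta", 10000)
config.num_key_value_heads = getattr(
    config, "num_key_value_heads", config.num_attention_heads
)

print(config.rope_theta)           # 10000 (default applied)
print(config.num_key_value_heads)  # 32 (falls back to num_attention_heads)
```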