Unverified commit a64d407d authored by drbh, committed by GitHub

fix: default num_ln_in_parallel_attn to one if not supplied (#2364)

parent 1768c00b
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
...
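For context, a minimal sketch of the pattern the fix relies on: getattr with a default of 1 lets configs that omit num_ln_in_parallel_attn still construct the layer, while configs that set it keep their value. This is not the actual text-generation-inference code; SimpleNamespace stands in for the real Falcon config object.

# Minimal sketch, assuming SimpleNamespace as a stand-in for the Falcon config.
from types import SimpleNamespace

falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)  # config that provides the field
legacy_config = SimpleNamespace()                            # older config without the field

# Direct attribute access would raise AttributeError on the legacy config:
#   legacy_config.num_ln_in_parallel_attn  ->  AttributeError
# getattr with a default of 1 handles both cases:
print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # 2
print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # 1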