Unverified Commit 133015f4 authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: prefer original layernorm names for 180B (#2365)

parent a64d407d
......@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
prefix = f"{prefix}.h.{layer_id}"
# NOTE: Falcon 180B uses the ln_attn prefix
ln_prefix = "input_layernorm"
if config.num_hidden_layers == 80:
ln_prefix = "ln_attn"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
prefix=f"{prefix}.{ln_prefix}",
weights=weights,
eps=config.layer_norm_epsilon,
)
......@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
# in the case no number of layer norms is provided, we default to 1
self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
# Falcon 180B uses the ln_attn prefix and has 2 layer norms
if config.num_hidden_layers == 80:
self.num_ln = 2
if self.num_ln == 1:
self.input_ln = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment