"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "01c056f09441a8670d0a88f24e2d4fb4a2956ae8"
Unverified commit 133015f4, authored by drbh and committed by GitHub

fix: prefer original layernorm names for 180B (#2365)

parent a64d407d
@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
         prefix = f"{prefix}.h.{layer_id}"
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
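
For context, a minimal sketch of the prefix selection this hunk introduces, assuming only a config object that exposes num_hidden_layers; the helper name below is illustrative and not part of the repository:

# Hypothetical helper mirroring the first hunk: pick the checkpoint name of the
# attention layer norm based on the model size.
def attn_layernorm_prefix(config) -> str:
    # Falcon 180B checkpoints (80 hidden layers) keep the original ln_attn name;
    # other RW checkpoints use input_layernorm.
    if config.num_hidden_layers == 80:
        return "ln_attn"
    return "input_layernorm"

With such a helper, the load call above would amount to FastLayerNorm.load(prefix=f"{prefix}.{attn_layernorm_prefix(config)}", ...), which is what the inlined ln_prefix variable achieves.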
@@ -477,6 +477,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
...
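
Similarly, a hedged sketch of the layer-norm count logic from the second hunk; the attribute names come from the diff, while the helper itself is hypothetical:

# Hypothetical helper mirroring the second hunk: decide how many layer norms
# FlashRWLayerNorm should load for the parallel-attention block.
def num_parallel_attn_layernorms(config) -> int:
    # Default to 1 when the config does not provide num_ln_in_parallel_attn.
    num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
    # Per the diff comment, Falcon 180B (80 hidden layers) has 2 layer norms.
    if config.num_hidden_layers == 80:
        num_ln = 2
    return num_ln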