Unverified Commit bd6e8b3c authored by drbh, committed by GitHub

fix: adjust llama MLP name from dense to mlp to correctly apply lora (#2760)

parent 5489406c
@@ -422,7 +422,7 @@ class FlashLlamaLayer(nn.Module):
                 if SparseMoELayer.is_supported(weights)
                 else DenseMoELayer
             )
-            self.dense = Phi3MoE(
+            self.mlp = Phi3MoE(
                 f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
             )
             # with moe the layernorms are are not rmsnorms and they have bias
@@ -437,7 +437,7 @@ class FlashLlamaLayer(nn.Module):
                 eps=config.rms_norm_eps,
             )
         else:
-            self.dense = LlamaMLP(
+            self.mlp = LlamaMLP(
                 prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
             )
             self.input_layernorm = FastRMSNorm.load(
@@ -493,7 +493,7 @@ class FlashLlamaLayer(nn.Module):
             attn_output, res
         )

-        mlp_output = self.dense(normed_attn_res_output, adapter_data)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier
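Why the attribute name matters: LoRA tooling typically decides which submodules receive adapter weights by matching names in the module tree, so an MLP stored under self.dense is never matched by an adapter configuration keyed on "mlp". The sketch below is a minimal, hypothetical illustration of that name-based matching; ToyLayer and find_lora_targets are made-up names for illustration, not TGI's actual adapter API.

import torch.nn as nn


class ToyLayer(nn.Module):
    """Toy stand-in for a transformer layer; only the attribute name differs."""

    def __init__(self, use_mlp_name: bool):
        super().__init__()
        block = nn.Linear(16, 16)
        if use_mlp_name:
            self.mlp = block      # discoverable by a config targeting "mlp"
        else:
            self.dense = block    # silently skipped by the same config


def find_lora_targets(model: nn.Module, target_names: set) -> list:
    """Return qualified names of submodules a name-based LoRA config would hit."""
    hits = []
    for qualified_name, _ in model.named_modules():
        leaf = qualified_name.rsplit(".", 1)[-1]  # e.g. "layers.0.mlp" -> "mlp"
        if leaf in target_names:
            hits.append(qualified_name)
    return hits


print(find_lora_targets(ToyLayer(use_mlp_name=True), {"mlp"}))   # ['mlp']
print(find_lora_targets(ToyLayer(use_mlp_name=False), {"mlp"}))  # []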