Unverified Commit d3317bbb authored by Paul Pak's avatar Paul Pak Committed by GitHub
Browse files

[Models] Lfm2Moe: minor name changes for resolving lora conflicts (#29063)


Signed-off-by: default avatarPaul Pak <paulpak58@gmail.com>
parent 8e61425e
......@@ -248,7 +248,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.conv = ShortConv(
self.short_conv = ShortConv(
config=config,
dim=config.conv_dim,
layer_idx=layer_idx,
......@@ -281,7 +281,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
else:
hidden_states, residual = self.operator_norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.conv(
self.short_conv(
hidden_states,
output,
)
......@@ -380,6 +380,9 @@ class Lfm2Model(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if ".conv." in name:
name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
......@@ -414,6 +417,7 @@ class Lfm2ForCausalLM(
"w1",
"w3",
],
"in_proj": ["in_proj"],
}
# LoRA specific attributes
......
......@@ -349,7 +349,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.conv = ShortConv(
self.short_conv = ShortConv(
config=config,
dim=config.hidden_size,
layer_idx=layer_idx,
......@@ -388,7 +388,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module):
else:
hidden_states, residual = self.operator_norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.conv(
self.short_conv(
hidden_states,
output,
)
......@@ -509,6 +509,9 @@ class Lfm2MoeModel(nn.Module):
if "expert_bias" in name:
name = name.replace("expert_bias", "gate.e_score_correction_bias")
if ".conv." in name:
name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -595,6 +598,7 @@ class Lfm2MoeForCausalLM(
"w1",
"w3",
],
"in_proj": ["in_proj"],
}
# LoRA specific attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment