Unverified commit d3317bbb, authored by Paul Pak and committed by GitHub
Browse files

[Models] Lfm2Moe: minor name changes for resolving lora conflicts (#29063)


Signed-off-by: Paul Pak <paulpak58@gmail.com>
parent 8e61425e
...@@ -248,7 +248,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module): ...@@ -248,7 +248,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.conv = ShortConv( self.short_conv = ShortConv(
config=config, config=config,
dim=config.conv_dim, dim=config.conv_dim,
layer_idx=layer_idx, layer_idx=layer_idx,
...@@ -281,7 +281,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module): ...@@ -281,7 +281,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.operator_norm(hidden_states, residual) hidden_states, residual = self.operator_norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.conv( self.short_conv(
hidden_states, hidden_states,
output, output,
) )
...@@ -380,6 +380,9 @@ class Lfm2Model(nn.Module): ...@@ -380,6 +380,9 @@ class Lfm2Model(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ".conv." in name:
name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
...@@ -414,6 +417,7 @@ class Lfm2ForCausalLM( ...@@ -414,6 +417,7 @@ class Lfm2ForCausalLM(
"w1", "w1",
"w3", "w3",
], ],
"in_proj": ["in_proj"],
} }
# LoRA specific attributes # LoRA specific attributes
......
...@@ -349,7 +349,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module): ...@@ -349,7 +349,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.conv = ShortConv( self.short_conv = ShortConv(
config=config, config=config,
dim=config.hidden_size, dim=config.hidden_size,
layer_idx=layer_idx, layer_idx=layer_idx,
...@@ -388,7 +388,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module): ...@@ -388,7 +388,7 @@ class Lfm2MoeShortConvDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.operator_norm(hidden_states, residual) hidden_states, residual = self.operator_norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
self.conv( self.short_conv(
hidden_states, hidden_states,
output, output,
) )
...@@ -509,6 +509,9 @@ class Lfm2MoeModel(nn.Module): ...@@ -509,6 +509,9 @@ class Lfm2MoeModel(nn.Module):
if "expert_bias" in name: if "expert_bias" in name:
name = name.replace("expert_bias", "gate.e_score_correction_bias") name = name.replace("expert_bias", "gate.e_score_correction_bias")
if ".conv." in name:
name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
...@@ -595,6 +598,7 @@ class Lfm2MoeForCausalLM( ...@@ -595,6 +598,7 @@ class Lfm2MoeForCausalLM(
"w1", "w1",
"w3", "w3",
], ],
"in_proj": ["in_proj"],
} }
# LoRA specific attributes # LoRA specific attributes
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment