Unverified commit 908a7134, authored by ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟, committed by GitHub
Browse files

[Bugfix] LoRA: extend expert base_layer loading to Qwen3.5 and Step3.x (#37114)


Signed-off-by: Hollow Man <hollowman@opensuse.org>
parent ec5ef0ac
@@ -306,9 +306,12 @@ class Qwen3_5Model(Qwen3NextModel):
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         is_fused_expert = False
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         fused_expert_params_mapping = [
-            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
-            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+            (f"experts.{base_layer}w13_weight", "experts.gate_up_proj", 0, "w1"),
+            (f"experts.{base_layer}w2_weight", "experts.down_proj", 0, "w2"),
         ]
         num_experts = (
             self.config.num_experts if hasattr(self.config, "num_experts") else 0
......
@@ -207,9 +207,12 @@ class Qwen3_5MultiTokenPredictor(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         is_fused_expert = False
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         fused_expert_params_mapping = [
-            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
-            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+            (f"experts.{base_layer}w13_weight", "experts.gate_up_proj", 0, "w1"),
+            (f"experts.{base_layer}w2_weight", "experts.down_proj", 0, "w2"),
         ]
         num_experts = (
             self.config.num_experts if hasattr(self.config, "num_experts") else 0
......
@@ -183,9 +183,12 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         is_fused_expert = False
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         fused_expert_params_mapping = [
-            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
-            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+            (f"experts.{base_layer}w13_weight", "experts.gate_up_proj", 0, "w1"),
+            (f"experts.{base_layer}w2_weight", "experts.down_proj", 0, "w2"),
         ]
         num_experts = self.config.num_experts
         for name, loaded_weight in weights:
......
@@ -463,11 +463,14 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         expert_params_mapping = [
-            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
-            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
-            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.gate_proj.weight", "w1"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.up_proj.weight", "w3"),
+            (f".moe.experts.{base_layer}w2_weight", ".moe.down_proj.weight", "w2"),
         ]
         disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
......
@@ -626,12 +626,15 @@ class Step3p5Model(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         # Old packed 3D format: .moe.gate_proj.weight [num_experts, out, in]
         expert_params_mapping = [
-            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
-            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
-            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.gate_proj.weight", "w1"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.up_proj.weight", "w3"),
+            (f".moe.experts.{base_layer}w2_weight", ".moe.down_proj.weight", "w2"),
         ]
         # New per-expert format: .moe.experts.E.gate_proj.weight_packed [out, in]
......
@@ -181,14 +181,17 @@ class Step3p5MTP(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        params_dict = dict(self.named_parameters())
+        base_layer = (
+            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
+        )
         expert_params_mapping = [
-            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
-            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
-            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.gate_proj.weight", "w1"),
+            (f".moe.experts.{base_layer}w13_weight", ".moe.up_proj.weight", "w3"),
+            (f".moe.experts.{base_layer}w2_weight", ".moe.down_proj.weight", "w2"),
         ]
-        params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment