Unverified Commit f0a30a06 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix qwen-moe packed_modules_mapping (#26634)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 9d6cff3e
...@@ -325,7 +325,7 @@ class SupportsLoRA(Protocol): ...@@ -325,7 +325,7 @@ class SupportsLoRA(Protocol):
# are empty by default. # are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {} embedding_modules: ClassVar[dict[str, str]] = {}
embedding_padding_modules: ClassVar[list[str]] = [] embedding_padding_modules: ClassVar[list[str]] = []
packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} packed_modules_mapping: dict[str, list[str]] = {}
# We can't use runtime_checkable with ClassVar for issubclass checks # We can't use runtime_checkable with ClassVar for issubclass checks
......
...@@ -534,11 +534,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -534,11 +534,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"q_proj", "q_proj",
"k_proj", "k_proj",
"v_proj", "v_proj",
], ]
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -547,6 +543,18 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -547,6 +543,18 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
# Only perform the following mapping when Qwen2MoeMLP exists
if (
getattr(config, "mlp_only_layers", [])
or config.shared_expert_intermediate_size > 0
):
self.packed_modules_mapping["gate_up_proj"] = (
[
"gate_proj",
"up_proj",
],
)
self.model = Qwen2MoeModel( self.model = Qwen2MoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
......
...@@ -634,11 +634,7 @@ class Qwen3MoeForCausalLM( ...@@ -634,11 +634,7 @@ class Qwen3MoeForCausalLM(
"q_proj", "q_proj",
"k_proj", "k_proj",
"v_proj", "v_proj",
], ]
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -649,6 +645,14 @@ class Qwen3MoeForCausalLM( ...@@ -649,6 +645,14 @@ class Qwen3MoeForCausalLM(
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
# Only perform the following mapping when Qwen3MoeMLP exists
if getattr(config, "mlp_only_layers", []):
self.packed_modules_mapping["gate_up_proj"] = (
[
"gate_proj",
"up_proj",
],
)
self.model = Qwen3MoeModel( self.model = Qwen3MoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment