Unverified Commit 1c2c1eb8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Rename FusedMoE.make_expert_params_mapping to...


[MoE Refactor] Rename FusedMoE.make_expert_params_mapping to fused_moe_make_expert_params_mapping (#40671)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 8824f50f
...@@ -44,7 +44,10 @@ from vllm.model_executor.layers.attention import ( ...@@ -44,7 +44,10 @@ from vllm.model_executor.layers.attention import (
Attention, Attention,
StaticSinkAttention, StaticSinkAttention,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -1149,7 +1152,7 @@ class OpenPanguModel(nn.Module): ...@@ -1149,7 +1152,7 @@ class OpenPanguModel(nn.Module):
] ]
has_experts = hasattr(self.config, "n_routed_experts") has_experts = hasattr(self.config, "n_routed_experts")
if has_experts: if has_experts:
expert_merge_mapping = FusedMoE.make_expert_params_mapping( expert_merge_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -28,7 +28,9 @@ from vllm.config import VllmConfig ...@@ -28,7 +28,9 @@ from vllm.config import VllmConfig
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -147,7 +149,7 @@ class OpenPanguMTP(nn.Module): ...@@ -147,7 +149,7 @@ class OpenPanguMTP(nn.Module):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -32,7 +32,10 @@ from vllm.distributed import ( ...@@ -32,7 +32,10 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -690,7 +693,7 @@ class Param2MoEModel(nn.Module): ...@@ -690,7 +693,7 @@ class Param2MoEModel(nn.Module):
return loaded_params return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -35,7 +35,10 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -35,7 +35,10 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
...@@ -514,7 +517,7 @@ class PhiMoEModel(nn.Module): ...@@ -514,7 +517,7 @@ class PhiMoEModel(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="w1", ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2", ckpt_down_proj_name="w2",
......
...@@ -40,7 +40,10 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size ...@@ -40,7 +40,10 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -12,7 +12,9 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -12,7 +12,9 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -194,7 +196,7 @@ class Qwen3_5MultiTokenPredictor(nn.Module): ...@@ -194,7 +196,7 @@ class Qwen3_5MultiTokenPredictor(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -43,7 +43,10 @@ from vllm.distributed import ( ...@@ -43,7 +43,10 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -516,7 +519,7 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin): ...@@ -516,7 +519,7 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -23,7 +23,10 @@ from vllm.distributed import ( ...@@ -23,7 +23,10 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm, GemmaRMSNorm as Qwen3NextRMSNorm,
) )
...@@ -533,7 +536,7 @@ class Qwen3NextModel(nn.Module, EagleModelMixin): ...@@ -533,7 +536,7 @@ class Qwen3NextModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -11,7 +11,9 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -11,7 +11,9 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -145,7 +147,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -145,7 +147,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -35,7 +35,10 @@ from vllm.distributed import ( ...@@ -35,7 +35,10 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -529,7 +532,7 @@ class SarvamMLAModel(nn.Module): ...@@ -529,7 +532,7 @@ class SarvamMLAModel(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -18,7 +18,9 @@ from vllm.distributed import ( ...@@ -18,7 +18,9 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
......
...@@ -23,7 +23,10 @@ from vllm.distributed import ( ...@@ -23,7 +23,10 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -637,7 +640,7 @@ class Step3p5Model(nn.Module): ...@@ -637,7 +640,7 @@ class Step3p5Model(nn.Module):
] ]
# New per-expert format: .moe.experts.E.gate_proj.weight_packed [out, in] # New per-expert format: .moe.experts.E.gate_proj.weight_packed [out, in]
per_expert_mapping = FusedMoE.make_expert_params_mapping( per_expert_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -25,7 +25,10 @@ from vllm.config.utils import getattr_iter ...@@ -25,7 +25,10 @@ from vllm.config.utils import getattr_iter
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -179,7 +182,7 @@ class MoEMixin(MixtureOfExperts): ...@@ -179,7 +182,7 @@ class MoEMixin(MixtureOfExperts):
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
for gate_proj, down_proj, up_proj in ckpt_names: for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend( expert_mapping.extend(
FusedMoE.make_expert_params_mapping( fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name=gate_proj, ckpt_gate_proj_name=gate_proj,
ckpt_down_proj_name=down_proj, ckpt_down_proj_name=down_proj,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment