Unverified commit 1c2c1eb8, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Rename FusedMoE.make_expert_params_mapping to...


[MoE Refactor] Rename FusedMoE.make_expert_params_mapping to fused_moe_make_expert_params_mapping (#40671)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent 8824f50f
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
fused_moe_make_expert_params_mapping,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEActivationFormat,
...@@ -65,6 +66,7 @@ __all__ = [ ...@@ -65,6 +66,7 @@ __all__ = [
"RoutingMethodType", "RoutingMethodType",
"activation_without_mul", "activation_without_mul",
"apply_moe_activation", "apply_moe_activation",
"fused_moe_make_expert_params_mapping",
"override_config", "override_config",
"get_config", "get_config",
] ]
......
...@@ -1618,6 +1618,25 @@ class FusedMoE(PluggableLayer): ...@@ -1618,6 +1618,25 @@ class FusedMoE(PluggableLayer):
return s return s
# This is a temporary forwarding function which will be removed/modified later.
def fused_moe_make_expert_params_mapping(
    model: torch.nn.Module,
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
    num_redundant_experts: int = 0,
) -> list[tuple[str, str, int, str]]:
    """Build the expert parameter mapping by delegating to ``FusedMoE``.

    Module-level entry point that hands every argument, unchanged and in the
    same positional order, to ``FusedMoE.make_expert_params_mapping`` and
    returns that call's result untouched.

    Args:
        model: The module the mapping is built for (call sites pass ``self``).
        ckpt_gate_proj_name: Checkpoint name of the gate projection
            (call sites pass e.g. ``"gate_proj"``).
        ckpt_down_proj_name: Checkpoint name of the down projection
            (call sites pass e.g. ``"down_proj"``).
        ckpt_up_proj_name: Checkpoint name of the up projection.
        num_experts: Expert count, forwarded as-is.
        num_redundant_experts: Redundant expert count, forwarded as-is;
            defaults to 0.

    Returns:
        Whatever ``FusedMoE.make_expert_params_mapping`` returns; annotated as
        a list of 4-tuples, which call sites describe as
        ``(param_name, weight_name, expert_id, shard_id)``.
    """
    # Arguments are forwarded positionally so this wrapper does not depend on
    # the parameter names used by the underlying implementation.
    forwarded_args = (
        model,
        ckpt_gate_proj_name,
        ckpt_down_proj_name,
        ckpt_up_proj_name,
        num_experts,
        num_redundant_experts,
    )
    return FusedMoE.make_expert_params_mapping(*forwarded_args)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters # Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code # to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
...@@ -42,7 +42,10 @@ from vllm.distributed import ( ...@@ -42,7 +42,10 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -916,7 +919,7 @@ class AXK1ForCausalLM( ...@@ -916,7 +919,7 @@ class AXK1ForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -950,7 +953,7 @@ class AXK1ForCausalLM( ...@@ -950,7 +953,7 @@ class AXK1ForCausalLM(
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -18,7 +18,10 @@ from vllm.distributed import ( ...@@ -18,7 +18,10 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -479,7 +482,7 @@ class AfmoeModel(nn.Module, EagleModelMixin): ...@@ -479,7 +482,7 @@ class AfmoeModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -18,7 +18,10 @@ from vllm.distributed import ( ...@@ -18,7 +18,10 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
......
...@@ -14,7 +14,9 @@ from vllm.config.multimodal import BaseDummyOptions ...@@ -14,7 +14,9 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import MultiModalDataDict from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
......
...@@ -41,7 +41,10 @@ from vllm.distributed import ( ...@@ -41,7 +41,10 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -461,7 +464,7 @@ class BailingMoeModel(nn.Module): ...@@ -461,7 +464,7 @@ class BailingMoeModel(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -21,7 +21,10 @@ from vllm.model_executor.layers.fla.ops.layernorm_guard import ( ...@@ -21,7 +21,10 @@ from vllm.model_executor.layers.fla.ops.layernorm_guard import (
RMSNormGated, RMSNormGated,
layernorm_fn, layernorm_fn,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -990,7 +993,7 @@ class BailingMoeV25Model(nn.Module): ...@@ -990,7 +993,7 @@ class BailingMoeV25Model(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""Get expert parameter mapping for MoE layers.""" """Get expert parameter mapping for MoE layers."""
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -15,7 +15,9 @@ from vllm.distributed import ( ...@@ -15,7 +15,9 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
......
...@@ -8,7 +8,9 @@ import torch.nn as nn ...@@ -8,7 +8,9 @@ import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -105,7 +107,7 @@ class DeepseekV2Model(nn.Module): ...@@ -105,7 +107,7 @@ class DeepseekV2Model(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -11,7 +11,9 @@ from vllm._aiter_ops import rocm_aiter_ops ...@@ -11,7 +11,9 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -252,7 +254,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -252,7 +254,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
] ]
stacked_params_mapping.extend(indexer_fused_mapping) stacked_params_mapping.extend(indexer_fused_mapping)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
GateLinear, GateLinear,
RoutingMethodType, RoutingMethodType,
fused_moe_make_expert_params_mapping,
) )
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -1432,7 +1433,7 @@ class DeepseekV2ForCausalLM( ...@@ -1432,7 +1433,7 @@ class DeepseekV2ForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -1474,7 +1475,7 @@ class DeepseekV2ForCausalLM( ...@@ -1474,7 +1475,7 @@ class DeepseekV2ForCausalLM(
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -40,7 +40,10 @@ from vllm.distributed import ( ...@@ -40,7 +40,10 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -413,7 +416,7 @@ class Dots1Model(nn.Module): ...@@ -413,7 +416,7 @@ class Dots1Model(nn.Module):
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -42,7 +42,10 @@ from vllm.distributed import ( ...@@ -42,7 +42,10 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -485,7 +488,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -485,7 +488,7 @@ class Ernie4_5_MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -36,7 +36,10 @@ from vllm.config import CacheConfig, VllmConfig ...@@ -36,7 +36,10 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
...@@ -649,7 +652,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -649,7 +652,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -30,7 +30,10 @@ from vllm.distributed import ( ...@@ -30,7 +30,10 @@ from vllm.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -326,7 +329,7 @@ class ExaoneMoeModel(nn.Module): ...@@ -326,7 +329,7 @@ class ExaoneMoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -20,7 +20,9 @@ from torch import nn ...@@ -20,7 +20,9 @@ from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
......
...@@ -37,7 +37,10 @@ from vllm.forward_context import get_forward_context ...@@ -37,7 +37,10 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
......
...@@ -42,7 +42,10 @@ from vllm.distributed import ( ...@@ -42,7 +42,10 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -466,7 +469,7 @@ class Glm4MoeModel(nn.Module): ...@@ -466,7 +469,7 @@ class Glm4MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
...@@ -41,7 +41,9 @@ from vllm.distributed import ( ...@@ -41,7 +41,9 @@ from vllm.distributed import (
get_pp_group, get_pp_group,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import (
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -308,7 +310,7 @@ class Glm4MoeLiteModel(nn.Module): ...@@ -308,7 +310,7 @@ class Glm4MoeLiteModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -334,7 +336,7 @@ class Glm4MoeLiteModel(nn.Module): ...@@ -334,7 +336,7 @@ class Glm4MoeLiteModel(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -616,7 +618,7 @@ class Glm4MoeLiteForCausalLM( ...@@ -616,7 +618,7 @@ class Glm4MoeLiteForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping( return fused_moe_make_expert_params_mapping(
self, self,
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment