Commit 5e584ce9 (unverified), authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Remove SharedFusedMoE class (#35782)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 1842447c
......@@ -37,7 +37,7 @@ from vllm.distributed.parallel_state import (
get_eplb_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE, fused_experts
from vllm.model_executor.layers.fused_moe import FusedMoE, fused_experts
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.router.router_factory import (
......@@ -858,11 +858,7 @@ def make_fused_moe_layer(
quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts)
kwargs = dict()
if shared_experts is None:
builder = FusedMoE
else:
builder = SharedFusedMoE
kwargs["shared_experts"] = shared_experts
kwargs["shared_experts"] = shared_experts
# Add gate and routed_input_transform if provided
if gate is not None:
......@@ -872,7 +868,7 @@ def make_fused_moe_layer(
kwargs["routed_input_transform"] = routed_input_transform
kwargs["routed_output_transform"] = routed_output_transform
layer = builder(
layer = FusedMoE(
num_experts=global_num_experts,
top_k=top_k,
hidden_size=hidden_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for SharedFusedMoE with routed_input_transform.
Tests for FusedMoE with routed_input_transform.
Verifies that applying routed_input_transform inside SharedFusedMoE
Verifies that applying routed_input_transform inside FusedMoE
produces the same results as applying the transform manually outside.
"""
......@@ -13,7 +13,7 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer, set_random_seed
......@@ -133,9 +133,9 @@ def test_routed_input_transform_inside_vs_outside(
workspace_init,
monkeypatch,
):
"""Compare SharedFusedMoE with transform inside vs manually applying outside.
Method A (inside): SharedFusedMoE with routed_input_transform
Method B (outside): Manually transform, then SharedFusedMoE without transform
"""Compare FusedMoE with transform inside vs manually applying outside.
Method A (inside): FusedMoE with routed_input_transform
Method B (outside): Manually transform, then FusedMoE without transform
"""
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
......@@ -157,8 +157,8 @@ def test_routed_input_transform_inside_vs_outside(
routed_transform = SimpleLinear(hidden_size, latent_size, dtype)
with set_current_vllm_config(vllm_config):
# Method A: SharedFusedMoE WITH routed_input_transform
moe_with_transform = SharedFusedMoE(
# Method A: FusedMoE WITH routed_input_transform
moe_with_transform = FusedMoE(
shared_experts=shared_experts,
routed_input_transform=routed_transform,
num_experts=num_experts,
......@@ -173,9 +173,9 @@ def test_routed_input_transform_inside_vs_outside(
prefix="moe_with_transform",
)
# Method B: SharedFusedMoE WITHOUT routed_input_transform
# Method B: FusedMoE WITHOUT routed_input_transform
# Note: shared_experts=None because when transform is done outside,
moe_without_transform = SharedFusedMoE(
moe_without_transform = FusedMoE(
shared_experts=None,
routed_input_transform=None,
num_experts=num_experts,
......
......@@ -7,6 +7,8 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.utils import is_moe_layer
class Cache:
def __init__(self):
......@@ -317,16 +319,7 @@ class DeviceCommunicatorBase:
if not self.is_ep_communicator:
return
moe_modules = [
module
for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.maybe_init_modular_kernel?
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
moe_modules = [module for module in model.modules() if is_moe_layer(module)]
for module in moe_modules:
module.maybe_init_modular_kernel()
......
......@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.utils import is_moe_layer
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
......@@ -319,10 +320,7 @@ class ElasticEPScalingExecutor:
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
if is_moe_layer(module)
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
......
......@@ -610,7 +610,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
# source_layer is FusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
......@@ -772,5 +772,5 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
# source_layer is FusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
......@@ -29,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
......@@ -64,7 +63,6 @@ __all__ = [
"FusedMoEPrepareAndFinalizeModular",
"GateLinear",
"RoutingMethodType",
"SharedFusedMoE",
"activation_without_mul",
"apply_moe_activation",
"override_config",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Remove this entirely
class SharedFusedMoE(FusedMoE):
    """
    Compatibility subclass of FusedMoE kept under the SharedFusedMoE name.

    The class adds no state and no logic of its own: ``forward`` hands its
    arguments straight to ``FusedMoE.forward``. Shared-expert computation
    (and, presumably, any interleaving with the fused all2all dispatch
    communication step) is therefore left to the parent class -- confirm
    against FusedMoE before relying on it.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``FusedMoE.forward`` and return its result unchanged."""
        # Arguments are passed by keyword so they bind by name in the parent.
        moe_output = super().forward(
            router_logits=router_logits,
            hidden_states=hidden_states,
        )
        return moe_output
......@@ -42,7 +42,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -163,7 +163,7 @@ class AXK1MoE(nn.Module):
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
......@@ -916,7 +916,7 @@ class AXK1ForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -950,7 +950,7 @@ class AXK1ForCausalLM(
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -18,7 +18,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -124,8 +124,8 @@ class AfmoeMoE(nn.Module):
prefix=f"{prefix}.shared_experts",
)
# Routed experts using SharedFusedMoE
self.experts = SharedFusedMoE(
# Routed experts using FusedMoE
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
......@@ -479,7 +479,7 @@ class AfmoeModel(nn.Module, EagleModelMixin):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -637,7 +637,7 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers
self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = []
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
......
......@@ -14,7 +14,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -214,7 +214,7 @@ class AriaProjector(nn.Module):
return out
class AriaFusedMoE(SharedFusedMoE):
class AriaFusedMoE(FusedMoE):
def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
) -> None:
......
......@@ -41,7 +41,7 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -285,7 +285,7 @@ class BailingMoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
......@@ -461,7 +461,7 @@ class BailingMoeModel(nn.Module):
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fla.ops.layernorm_guard import (
RMSNormGated,
layernorm_fn,
)
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -351,8 +351,8 @@ class BailingMoeV25(nn.Module):
else:
self.shared_experts = None
# Routed experts using SharedFusedMoE
self.experts = SharedFusedMoE(
# Routed experts using FusedMoE
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=self.num_experts,
top_k=self.top_k,
......
......@@ -11,7 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -252,7 +252,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
]
stacked_params_mapping.extend(indexer_fused_mapping)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -48,9 +48,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
RoutingMethodType,
SharedFusedMoE,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -311,7 +311,7 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
......@@ -1432,7 +1432,7 @@ class DeepseekV2ForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -1474,7 +1474,7 @@ class DeepseekV2ForCausalLM(
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -40,7 +40,7 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -155,7 +155,7 @@ class Dots1MoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -413,7 +413,7 @@ class Dots1Model(nn.Module):
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -42,7 +42,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -188,7 +188,7 @@ class Ernie4_5_MoeMoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts,
top_k=config.moe_k,
......@@ -485,7 +485,7 @@ class Ernie4_5_MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -667,7 +667,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
self.num_moe_layers = len(moe_layers_indices)
self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = []
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
......
......@@ -36,7 +36,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
......@@ -257,7 +257,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix=f"{prefix}.text_experts_gate",
)
self.text_experts = SharedFusedMoE(
self.text_experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts[0],
top_k=config.moe_k,
......@@ -294,7 +294,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix=f"{prefix}.vision_experts_gate",
)
self.vision_experts = SharedFusedMoE(
self.vision_experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.moe_num_experts[1],
top_k=config.moe_k,
......@@ -649,7 +649,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -31,7 +31,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
......@@ -130,7 +129,7 @@ class ExaoneMoe(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=self.n_routed_experts,
......
......@@ -42,7 +42,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -178,7 +178,7 @@ class Glm4MoE(nn.Module):
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
self.experts = FusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -466,7 +466,7 @@ class Glm4MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
......@@ -41,7 +41,7 @@ from vllm.distributed import (
get_pp_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -308,7 +308,7 @@ class Glm4MoeLiteModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -334,7 +334,7 @@ class Glm4MoeLiteModel(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -616,7 +616,7 @@ class Glm4MoeLiteForCausalLM(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment