Unverified commit 9effeb5b, authored by Cheng Wan, committed by GitHub

Support EPLB in FusedMoE (#8448)

parent 1992ef9b
@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
         rank: int,
     ):
         if server_args.expert_distribution_recorder_mode is not None:
+            assert expert_location_metadata is not None, (
+                "ExpertLocationMetadata is required for expert distribution recording. One possible "
+                "reason is that you are using a model that does not support expert distribution "
+                "recording. Try setting `get_model_config_for_expert_location` in your model."
+            )
             return _ExpertDistributionRecorderReal(
                 server_args, expert_location_metadata, rank
             )
...
@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
     def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
         num_physical_experts = common["num_physical_experts"]
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_layers = model_config_for_expert_location.num_layers
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
         physical_to_logical_map = physical_to_logical_map.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
             physical_to_logical_map,
@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
         logical_count = logical_count.to(server_args.device)
         common = ExpertLocationMetadata._init_common(server_args, model_config)
+        if common is None:
+            return None
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
         num_groups = model_config_for_expert_location.num_groups
@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
             ModelConfigForExpertLocation.from_model_config(model_config)
         )
+        if model_config_for_expert_location is None:
+            return None
         num_physical_experts = (
             model_config_for_expert_location.num_logical_experts
             + server_args.ep_num_redundant_experts
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
     num_logical_experts: int
     num_groups: Optional[int] = None

-    @staticmethod
-    def init_dummy():
-        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
-
     @staticmethod
     def from_model_config(model_config: ModelConfig):
         model_class, _ = get_model_architecture(model_config)
@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
                 model_config.hf_config
             )
         else:
-            return ModelConfigForExpertLocation.init_dummy()
+            return None


 def compute_initial_expert_location_metadata(
     server_args: ServerArgs, model_config: ModelConfig
-) -> ExpertLocationMetadata:
+) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
         return ExpertLocationMetadata.init_trivial(server_args, model_config)
...
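With the hunks above, `ModelConfigForExpertLocation.from_model_config` returns `None` for architectures that do not define `get_model_config_for_expert_location`, and every `ExpertLocationMetadata` initializer propagates that `None` instead of building the removed `init_dummy()` placeholder. A minimal sketch of this Optional-propagation contract, with simplified names and signatures rather than the real sglang ones:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ModelConfigForExpertLocation:
        num_layers: int
        num_logical_experts: int
        num_groups: Optional[int] = None

        @staticmethod
        def from_model_config(model_class) -> Optional["ModelConfigForExpertLocation"]:
            # Only model classes that define the hook yield a config; all others
            # now produce None instead of a dummy placeholder.
            hook = getattr(model_class, "get_model_config_for_expert_location", None)
            return hook() if hook is not None else None


    def compute_initial_expert_location_metadata(model_class) -> Optional[dict]:
        config = ModelConfigForExpertLocation.from_model_config(model_class)
        if config is None:
            return None  # non-MoE model: no expert-location metadata is created
        return {"num_layers": config.num_layers, "num_experts": config.num_logical_experts}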
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
     def init_new(cls, layer_id: int):
         ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
         expert_location_metadata = get_global_expert_location_metadata()
+        assert expert_location_metadata is not None
         if ep_dispatch_algorithm is None:
             return None
...
@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
         torch.cuda.empty_cache()
         old_expert_location_metadata = get_global_expert_location_metadata()
+        assert old_expert_location_metadata is not None
         _update_expert_weights(
             routed_experts_weights_of_layer=routed_experts_weights_of_layer,
             old_expert_location_metadata=old_expert_location_metadata,
...
@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
             params_dtype=params_dtype,
             quant_config=quant_config,
@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
+            self.layer_id, expert_id
+        )
         for physical_expert_id in physical_expert_ids:
             self._weight_loader_physical(
@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
...
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (
@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map = None
         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
...
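The FusedMoE hunk above carries the core of the EPLB support: under EPLB a logical (routed) expert may be replicated onto several physical expert slots, so the weight loader fans each checkpoint tensor out to every physical replica returned by `logical_to_all_physical`, while fused shared experts keep a fixed one-to-one mapping and models without expert-location metadata fall back to the plain loader. A self-contained sketch of that fan-out, using hypothetical helpers (`logical_to_all_physical` as a plain dict, `load_one_expert` as a callback) rather than the sglang objects:

    from typing import Callable, Dict, List

    import torch


    def fan_out_expert_weight(
        logical_to_all_physical: Dict[int, List[int]],
        logical_expert_id: int,
        loaded_weight: torch.Tensor,
        num_experts: int,
        num_fused_shared_experts: int,
        load_one_expert: Callable[[int, torch.Tensor], None],
    ) -> None:
        """Write one logical expert's checkpoint tensor to all of its physical slots."""
        if logical_expert_id >= num_experts - num_fused_shared_experts:
            # Fused shared experts are not rebalanced by EPLB: one logical id
            # maps to exactly one physical id.
            physical_ids = [logical_expert_id]
        else:
            # A routed expert may have redundant replicas placed on several ranks.
            physical_ids = logical_to_all_physical[logical_expert_id]
        for physical_id in physical_ids:
            load_one_expert(physical_id, loaded_weight)

In the layer itself, the per-replica write is `_weight_loader_physical`, which still maps each global physical id to the local expert slot owned by the current rank and skips ids that do not belong to it.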
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
...
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
...
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
...
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
...
@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
+            layer_id=layer_id,
             quant_config=quant_config,
         )
...
@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
     def __init__(
         self,
         config: Llama4TextConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
             num_experts=config.num_local_experts,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size_moe,
+            layer_id=layer_id,
             reduce_results=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
+                layer_id=layer_id,
                 quant_config=quant_config,
                 prefix=add_prefix("feed_forward", prefix),
             )
...
@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("block_sparse_moe", prefix),
         )
...
@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        layer_id: int = 0,
         prefix: str = "",
     ):
         super().__init__()
@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
+            layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
...
@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
...
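The remaining model files (Granite, Grok-1, HunYuan, Llama 4, Mixtral, OLMoE, Phi-MoE) all make the same mechanical change: `layer_id` is threaded from the decoder layer into the MoE block and on into `FusedMoE`, which now requires it so the weight loader can look up that layer's logical-to-physical mapping; DeepSeek-V2 and GLM-4.5 additionally pass `num_fused_shared_experts`. A hedged sketch of the resulting constructor contract, using a stand-in dataclass rather than the sglang class:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FusedMoEArgs:
        """Stand-in for the arguments FusedMoE.__init__ now expects."""

        num_experts: int
        hidden_size: int
        intermediate_size: int
        layer_id: int                       # newly required (was Optional before this commit)
        top_k: Optional[int] = None
        num_fused_shared_experts: int = 0   # > 0 only for shared-experts-fusion models


    def moe_args_for_layer(config, layer_id: int, num_fused_shared_experts: int = 0,
                           ep_num_redundant_experts: int = 0) -> FusedMoEArgs:
        # Mirrors the DeepseekV2MoE / Glm4MoeSparseMoeBlock call sites above;
        # the config attribute names are those models' HF config fields.
        return FusedMoEArgs(
            num_experts=config.n_routed_experts
            + num_fused_shared_experts
            + ep_num_redundant_experts,
            top_k=config.num_experts_per_tok + num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
        )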