Unverified Commit 9effeb5b authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Support EPLB in FusedMoE (#8448)

parent 1992ef9b
......@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
assert (
expert_location_metadata is not None
), "ExpertLocationMetadata is required for expert distribution recording. One possible"
"reason is that you are using a model that does not support expert distribution"
"recording. Try setting `get_model_config_for_expert_location` in your model."
return _ExpertDistributionRecorderReal(
server_args, expert_location_metadata, rank
)
......
......@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)
if common is None:
return None
num_physical_experts = common["num_physical_experts"]
model_config_for_expert_location = common["model_config_for_expert_location"]
num_layers = model_config_for_expert_location.num_layers
......@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
common = ExpertLocationMetadata._init_common(server_args, model_config)
if common is None:
return None
model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
......@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
logical_count = logical_count.to(server_args.device)
common = ExpertLocationMetadata._init_common(server_args, model_config)
if common is None:
return None
model_config_for_expert_location = common["model_config_for_expert_location"]
num_physical_experts = common["num_physical_experts"]
num_groups = model_config_for_expert_location.num_groups
......@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
ModelConfigForExpertLocation.from_model_config(model_config)
)
if model_config_for_expert_location is None:
return None
num_physical_experts = (
model_config_for_expert_location.num_logical_experts
+ server_args.ep_num_redundant_experts
......@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
num_logical_experts: int
num_groups: Optional[int] = None
@staticmethod
def init_dummy():
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
@staticmethod
def from_model_config(model_config: ModelConfig):
model_class, _ = get_model_architecture(model_config)
......@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
model_config.hf_config
)
else:
return ModelConfigForExpertLocation.init_dummy()
return None
def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
) -> Optional[ExpertLocationMetadata]:
data = server_args.init_expert_location
if data == "trivial":
return ExpertLocationMetadata.init_trivial(server_args, model_config)
......
......@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
expert_location_metadata = get_global_expert_location_metadata()
assert expert_location_metadata is not None
if ep_dispatch_algorithm is None:
return None
......
......@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
torch.cuda.empty_cache()
old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None
_update_expert_weights(
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
old_expert_location_metadata=old_expert_location_metadata,
......
......@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
top_k=top_k,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
params_dtype=params_dtype,
quant_config=quant_config,
......@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
shard_id: str,
expert_id: int,
) -> None:
physical_expert_ids = (
get_global_expert_location_metadata().logical_to_all_physical(
self.layer_id, expert_id
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
......@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
......
......@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
......@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
num_experts: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
top_k: Optional[int] = None,
layer_id: Optional[int] = None,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
......@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
......@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
)
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.expert_map = None
if enable_flashinfer_cutlass_moe and quant_config is None:
......@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
expert_id: int,
) -> None:
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
if expert_id >= self.num_experts - self.num_fused_shared_experts:
# This is a shared expert.
physical_expert_ids = [expert_id]
else:
physical_expert_ids = (
global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)
def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
......
......@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
......@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
......
......@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
......@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
......
......@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
params_dtype=params_dtype,
reduce_results=True,
quant_config=quant_config,
......@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe",
)
......
......@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
num_experts: int,
top_k: int,
hidden_size: int,
......@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
......@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
)
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
......
......@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
layer_id=layer_id,
quant_config=quant_config,
)
......
......@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
......@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
num_experts=config.num_local_experts,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
layer_id=layer_id,
reduce_results=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
......@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("feed_forward", prefix),
)
......
......@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
......@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
......@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("block_sparse_moe", prefix),
)
......
......@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
layer_id: int = 0,
prefix: str = "",
):
super().__init__()
......@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
layer_id=layer_id,
prefix=add_prefix("experts", prefix),
)
......@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
......
......@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment