Unverified Commit e6f11356 authored by Yi Zhang, committed by GitHub

support eplb for qwen3 (#6533)

parent 7b02c326
@@ -65,6 +65,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -88,7 +89,7 @@ def fused_topk(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     return topk_weights, topk_ids
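For context on the new `topk_ids_logical_to_physical` call: with EPLB, a logical expert can be backed by one or more physical expert slots (including the redundant ones), and the dispatch info carries the table that resolves which slot a token should be routed to. A minimal sketch of that kind of remapping, with illustrative names rather than the actual `ExpertLocationDispatchInfo` fields:

```python
import torch

def remap_topk_ids_sketch(
    topk_ids: torch.Tensor,                 # [num_tokens, top_k], logical expert ids
    logical_to_physical_map: torch.Tensor,  # [num_logical_experts], chosen physical slot per id
) -> torch.Tensor:
    # A plain gather: each logical expert id is replaced by the physical slot
    # (possibly a redundant replica) it is currently placed on.
    return logical_to_physical_map[topk_ids]

# Toy usage: logical expert 2 currently lives on physical slot 3.
topk_ids = torch.tensor([[0, 2], [1, 2]])
logical_to_physical_map = torch.tensor([0, 1, 3])
print(remap_topk_ids_sketch(topk_ids, logical_to_physical_map))
```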
@@ -355,12 +356,13 @@ def select_experts(
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk"
-        assert expert_location_dispatch_info is None
+        # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
         assert (
...
@@ -690,7 +690,9 @@ def _convert_global_physical_count_to_logical_count(
     )
     logical_count.scatter_add_(
         dim=2,
-        index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
+        index=physical_to_logical_map.unsqueeze(0)
+        .expand(dim_extra, -1, -1)
+        .to(torch.int64),
         src=global_physical_count,
     )
     return logical_count
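The extra `.to(torch.int64)` matters because `Tensor.scatter_add_` requires its `index` tensor to be int64; if `physical_to_logical_map` arrives as a 32-bit integer tensor, the call raises a dtype error. A standalone illustration of the requirement (the shapes below are made up for the example):

```python
import torch

dim_extra, num_layers, num_physical, num_logical = 2, 3, 4, 4
physical_to_logical_map = torch.randint(
    0, num_logical, (num_layers, num_physical), dtype=torch.int32
)
global_physical_count = torch.ones(dim_extra, num_layers, num_physical)
logical_count = torch.zeros(dim_extra, num_layers, num_logical)

# scatter_add_ expects an int64 index; casting the (int32) map avoids the
# "Expected dtype int64 for index"-style failure.
logical_count.scatter_add_(
    dim=2,
    index=physical_to_logical_map.unsqueeze(0)
    .expand(dim_extra, -1, -1)
    .to(torch.int64),
    src=global_physical_count,
)
print(logical_count)
```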
...
@@ -55,7 +55,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
@@ -67,6 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
@@ -86,28 +88,25 @@ logger = logging.getLogger(__name__)
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: Qwen3MoeConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        self.experts = MoEImpl(
-            num_experts=config.num_experts,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
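The inline `MoEImpl = (...)` conditional removed here is exactly the selection logic that `get_moe_impl_class()` now centralizes. Judging only from the removed code, the helper presumably looks roughly like the sketch below; the real function lives in `sglang.srt.layers.moe.ep_moe.layer` and may differ in detail:

```python
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict

def get_moe_impl_class_sketch():
    # Mirrors the conditional that used to be repeated at each call site:
    # DeepEPMoE when DeepEP is enabled, EPMoE for plain expert parallelism,
    # otherwise the Triton fused MoE.
    if global_server_args_dict["enable_deepep_moe"]:
        return DeepEPMoE
    if global_server_args_dict["enable_ep_moe"]:
        return EPMoE
    return FusedMoE
```

Centralizing the choice keeps the model file agnostic to which MoE backend is active and lets the redundant-expert count be folded into `num_experts` in one place.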
@@ -131,7 +130,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.num_experts
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
@@ -139,7 +140,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.num_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.num_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
@@ -157,8 +158,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         else:
             return self.forward_deepep(hidden_states, forward_mode)

+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
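The new `get_moe_weights` helper exposes the layer's routed-expert parameter tensors (everything except `correction_bias`) so the expert-location machinery can read or rearrange them. A toy illustration of the list it returns; the module and parameter names here are illustrative stand-ins, not the real MoE implementation:

```python
import torch
import torch.nn as nn

class ToyExperts(nn.Module):
    # Stand-in for the MoE layer; parameter names are illustrative only.
    def __init__(self, num_experts=8, hidden=4, inter=2):
        super().__init__()
        self.w13_weight = nn.Parameter(torch.randn(num_experts, inter * 2, hidden))
        self.w2_weight = nn.Parameter(torch.randn(num_experts, hidden, inter))
        self.correction_bias = nn.Parameter(torch.zeros(num_experts))  # filtered out

experts = ToyExperts()
moe_weights = [
    x.data
    for name, x in experts.named_parameters()
    if name not in ["correction_bias"]
]
print([tuple(w.shape) for w in moe_weights])  # two per-expert tensors, no correction_bias
```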
@@ -189,6 +196,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             topk_idx = torch.full(
@@ -408,6 +418,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         if self.info.is_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
+                layer_id=self.layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -685,15 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
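The mapping itself is unchanged: as the removed comment said, it is a list of `(param_name, weight_name, expert_id, shard_id)` tuples used to match per-expert checkpoint weights to the fused parameters; only the class it is looked up on now goes through `get_moe_impl_class()`. For readers unfamiliar with the format, entries look roughly like the hand-written examples below (illustrative values, not output of the real method):

```python
# Illustrative entries only; the real list comes from
# get_moe_impl_class().make_expert_params_mapping(...).
expert_params_mapping_example = [
    # (param_name,           weight_name,                   expert_id, shard_id)
    ("experts.w13_weight", "experts.0.gate_proj.weight", 0, "w1"),
    ("experts.w13_weight", "experts.0.up_proj.weight",   0, "w3"),
    ("experts.w2_weight",  "experts.0.down_proj.weight", 0, "w2"),
]
```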
@@ -770,5 +773,19 @@ class Qwen3MoeForCausalLM(nn.Module):
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")

+        self.routed_experts_weights_of_layer = {
+            layer_id: layer.mlp.get_moe_weights()
+            for layer_id, layer in enumerate(self.model.layers)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+        }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+

 EntryClass = Qwen3MoeForCausalLM
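Taken together, the two additions give the expert-location machinery what it needs for Qwen3-MoE: `get_model_config_for_expert_location` reports how many layers and logical experts the model has, and `routed_experts_weights_of_layer` hands over the per-layer tensors that can be rearranged. A rough sketch of how a consumer might use them; the function and its internals are assumptions for illustration, not the actual sglang expert-location manager:

```python
def inspect_expert_location_inputs(model, config):
    # Hypothetical consumer. Only ModelConfigForExpertLocation and the two
    # model-side attributes come from this commit; everything else is illustrative.
    location_cfg = model.get_model_config_for_expert_location(config)
    for layer_id, weights in model.routed_experts_weights_of_layer.items():
        # `weights` is the list returned by Qwen3MoeSparseMoeBlock.get_moe_weights();
        # a rebalancer would copy rows of these tensors when experts are relocated.
        assert 0 <= layer_id < location_cfg.num_layers
    return location_cfg
```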