Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
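
In broad strokes, this change moves MoE configuration out of per-model `global_server_args_dict[...]` lookups and into typed accessors exposed by `sglang.srt.layers.moe` (`get_moe_a2a_backend()`, `get_deepep_mode()`, `is_tbo_enabled()`, `get_tbo_token_distribution_threshold()`), and folds the scattered `enable_*_moe` booleans into a single `moe_runner_backend` choice. Below is a minimal, illustrative sketch of the accessor pattern only; it is not the actual sglang implementation, and the `MoeA2ABackend` enum and `initialize_moe_config` name here are stand-ins.

```python
# Sketch of the accessor pattern this refactor introduces (names simplified;
# the real code lives in sglang.srt.layers.moe).
from enum import Enum
from typing import Optional


class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        return self is MoeA2ABackend.DEEPEP


_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None


def initialize_moe_config(moe_a2a_backend: str) -> None:
    """Called once at server startup with the parsed ServerArgs value."""
    global _MOE_A2A_BACKEND
    _MOE_A2A_BACKEND = MoeA2ABackend(moe_a2a_backend)


def get_moe_a2a_backend() -> MoeA2ABackend:
    """Model code queries this accessor instead of global_server_args_dict."""
    assert _MOE_A2A_BACKEND is not None, "MoE config not initialized"
    return _MOE_A2A_BACKEND


# Call sites in the diff below become, e.g.:
#   if get_moe_a2a_backend().is_deepep(): ...
# instead of:
#   if global_server_args_dict["moe_a2a_backend"].is_deepep(): ...
```

The diff applies this pattern across deepseek_v2, glm4_moe, glm4v_moe, gpt_oss, qwen2/3_moe, step3, xverse_moe, two_batch_overlap, and server_args.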
......@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
......@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
)
self.shared_experts_is_int8 = False
......@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["moe_a2a_backend"].is_deepep()
if get_moe_a2a_backend().is_deepep()
else {}
),
)
......@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
......@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"],
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def get_moe_weights(self):
return [
......@@ -484,13 +460,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
......@@ -520,13 +490,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
......@@ -2478,17 +2442,15 @@ class DeepseekV2ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += (
get_moe_impl_class().make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts
)
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
......
......@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
......@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
......@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
......@@ -414,19 +411,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = (
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not should_use_flashinfer_trtllm_moe()
else None
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
......@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
)
self.shared_experts_is_int8 = False
......@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
......@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"],
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def forward_normal_dual_stream(
self,
......@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
......@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
......@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
......@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
......@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
......@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.moe_layer_freq = 1
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.dp_size = get_local_attention_dp_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
......@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -110,16 +110,13 @@ class GptOssSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit
self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.gemm1_clamp_limit = config.swiglu_limit
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
else:
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.top_k = config.num_experts_per_tok
experts_type = get_moe_impl_class()
......@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
quant_config.get_name() if quant_config is not None else None
)
extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
# for moe gate_up_proj and down_proj and their bias loading
"use_weight_loader_fused": quant_config_name != "mxfp4",
"use_weight_loader_fused": quant_config_name
!= "mxfp4"
}
self.experts = experts_type(
num_experts=config.num_local_experts
......@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
intermediate_size=config.intermediate_size,
quant_config=quant_config,
activation=self.activation,
activation_alpha=self.activation_alpha,
swiglu_limit=self.swiglu_limit,
gemm1_alpha=self.gemm1_alpha,
gemm1_clamp_limit=self.gemm1_clamp_limit,
with_bias=True,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**extra_kwargs,
)
......@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, should_allreduce_fusion)
else:
raise Exception("forward_deepep branch not implemented yet")
......@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# GptOss all layers are sparse and have no nextn now
self.is_layer_sparse = True
......@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
......
......@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
params_dtype=params_dtype,
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)
......
......@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
**kwargs,
)
......
......@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
......@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
]
expert_params_mapping = []
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
......@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
......
......@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda
......
......@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
......@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
)
......
......@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
intermediate_size=intermediate_size,
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
layer_id=layer_id,
prefix=add_prefix("experts", prefix),
)
......
......@@ -17,8 +17,6 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
......@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
......@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
self.gate = ReplicatedLinear(
......@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen2MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
......
......@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
Qwen3MoeConfig = None
......@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
self.gate = ReplicatedLinear(
......@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
......@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, use_reduce_scatter)
else:
return self.forward_deepep(hidden_states, forward_batch)
......@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen3MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
......@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
......@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
]
)
self.pack_params()
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.router = ReplicatedLinear(
config.hidden_size,
......@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
quant_config=None,
prefix=add_prefix("router", prefix),
)
self.topk = TopK(
top_k=self.top_k,
renormalize=getattr(self.config, "norm_topk_prob", False),
)
if config.num_shared_experts is not None:
intermediate_size = config.intermediate_size * config.num_shared_experts
......@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=getattr(self.config, "norm_topk_prob", False),
inplace=True,
topk_output,
self.moe_runner_config,
)
if self.config.num_shared_experts is not None:
......
......@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
......@@ -175,9 +176,15 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Optional[Literal["deepep"]] = None
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_runner_backend: Literal[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
......@@ -250,8 +257,6 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
scheduler_recv_interval: int = 1
# Debug tensor dumps
......@@ -282,6 +287,9 @@ class ServerArgs:
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
......@@ -298,6 +306,21 @@ class ServerArgs:
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
# Set missing default values
if self.tokenizer_path is None:
......@@ -517,7 +540,7 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel
if self.enable_flashinfer_cutlass_moe:
if self.moe_runner_backend == "flashinfer_cutlass":
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
......@@ -527,7 +550,7 @@ class ServerArgs:
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe:
if self.moe_runner_backend == "flashinfer_trtllm":
if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True
logger.warning(
......@@ -556,7 +579,7 @@ class ServerArgs:
self.ep_dispatch_algorithm = "static"
if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None
assert self.ep_size > 1
if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None
......@@ -1446,19 +1469,22 @@ class ServerArgs:
parser.add_argument(
"--moe-a2a-backend",
type=str,
choices=["deepep"],
choices=["none", "deepep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
"--moe-runner-backend",
type=str,
choices=[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
......@@ -1825,11 +1851,6 @@ class ServerArgs:
action="store_true",
help="Enable returning hidden states with responses.",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action="store_true",
......@@ -1965,6 +1986,21 @@ class ServerArgs:
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......@@ -2143,18 +2179,21 @@ class ServerArgs:
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.enable_triton_kernel_moe:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
......
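
For reference, the deprecated boolean flags above are folded into the single `moe_runner_backend` option in `ServerArgs.__post_init__`, so old command lines keep working with a deprecation warning. Here is a condensed, self-contained sketch of that mapping, simplified from the hunks above; `MoEArgsSketch` is a stand-in for the MoE-related fields of `ServerArgs`, and `warnings.warn` replaces `print_deprecated_warning`.

```python
import warnings
from dataclasses import dataclass
from typing import Literal


@dataclass
class MoEArgsSketch:
    """Condensed stand-in for the MoE-related fields of ServerArgs."""

    moe_a2a_backend: Literal["none", "deepep"] = "none"
    moe_runner_backend: str = "auto"
    # Deprecated booleans, kept only for backward compatibility.
    enable_flashinfer_cutlass_moe: bool = False
    enable_flashinfer_trtllm_moe: bool = False
    enable_triton_kernel_moe: bool = False

    def __post_init__(self):
        # Each deprecated flag maps onto the new enum-style option.
        legacy_to_backend = {
            "enable_triton_kernel_moe": "triton_kernel",
            "enable_flashinfer_cutlass_moe": "flashinfer_cutlass",
            "enable_flashinfer_trtllm_moe": "flashinfer_trtllm",
        }
        for flag, backend in legacy_to_backend.items():
            if getattr(self, flag):
                self.moe_runner_backend = backend
                warnings.warn(
                    f"--{flag.replace('_', '-')} is deprecated; "
                    f"use --moe-runner-backend {backend} instead."
                )


# Old invocation:  --enable-flashinfer-trtllm-moe
# New invocation:  --moe-runner-backend flashinfer_trtllm
args = MoEArgsSketch(enable_flashinfer_trtllm_moe=True)
assert args.moe_runner_backend == "flashinfer_trtllm"
```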
......@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_tbo_token_distribution_threshold,
is_tbo_enabled,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
......@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
left_sum = sum(extend_lens[:vanilla_split_seq_index])
overall_sum = sum(extend_lens)
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
threshold = get_tbo_token_distribution_threshold()
assert threshold <= 0.5, f"{threshold=}"
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
1 - threshold
......@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
if not is_tbo_enabled():
return
token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info
......@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
def prepare_all_gather(
self,
local_batch: ScheduleBatch,
deepep_mode: DeepEPMode,
enable_deepep_moe: bool,
enable_two_batch_overlap: bool,
):
deepep_mode = get_deepep_mode()
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
enable_two_batch_overlap = is_tbo_enabled()
self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None:
......@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
and (resolved_deepep_mode.is_low_latency())
)
else:
self.local_tbo_split_seq_index = 0
......@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
"req_to_token_pool",
"token_to_kv_pool",
"can_run_dp_cuda_graph",
"dp_padding_mode",
"global_forward_mode",
"spec_algorithm",
"capture_hidden_mode",
......@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
tbo_children=None,
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
dp_padding_mode=None,
global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
......@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs):
num_inner_dispatchers = (
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
)
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]
......
......@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
return True
elif not server_args.enable_dp_lm_head:
return True
elif server_args.moe_a2a_backend is None:
elif server_args.moe_a2a_backend == "none":
return True
else:
return (
......@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered.
"""
assert server_args.moe_dense_tp_size in [1, None]
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else:
......