Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
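
For orientation, a minimal sketch of the configuration-access pattern this commit moves the model code to: MoE settings are read through accessors on sglang.srt.layers.moe instead of global_server_args_dict. The helper below is illustrative (not part of the commit) and assumes an environment with the post-refactor sglang layout shown in the hunks that follow.

    # Illustrative helper only; the import and the is_deepep()/get_deepep_mode()
    # calls are the ones introduced in the diff below.
    from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend

    def deepep_dispatch_kwargs() -> dict:
        # Before this commit: global_server_args_dict["moe_a2a_backend"].is_deepep()
        if get_moe_a2a_backend().is_deepep():
            return {"deepep_mode": get_deepep_mode()}
        return {}
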
@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )

         self.shared_experts_is_int8 = False
@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
+                    if get_moe_a2a_backend().is_deepep()
                     else {}
                 ),
             )
@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
         self.top_k = config.num_experts_per_tok

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def get_moe_weights(self):
         return [
@@ -484,12 +460,6 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
-        else:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
@@ -520,12 +490,6 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
        kwargs = {"hidden_states": hidden_states}
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
-        else:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
@@ -2478,18 +2442,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
-            expert_params_mapping += (
-                get_moe_impl_class().make_expert_input_scale_params_mapping(
-                    num_experts=self.config.n_routed_experts
-                )
-            )
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
+            )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
...
@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
 class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
...
@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -414,8 +411,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
-        self.topk = (
-            TopK(
+        self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=True,
@@ -425,9 +421,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
         )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
-        )

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )

         self.shared_experts_is_int8 = False
@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.top_k = config.num_experts_per_tok

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def forward_normal_dual_stream(
         self,
@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.dp_size = get_local_attention_dp_size()

 class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()

         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
...
@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    tensor_model_parallel_all_reduce,
 )
 from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         config.moe_layer_freq = 1
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.dp_size = get_local_attention_dp_size()
         self.quant_config = quant_config
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,12 +110,9 @@ class GptOssSparseMoeBlock(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         self.activation = config.hidden_act
-        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
-        self.swiglu_limit = config.swiglu_limit
-        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
-            self.topk = None
-        else:
-            self.topk = TopK(
+        self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.gemm1_clamp_limit = config.swiglu_limit
+        self.topk = TopK(
             top_k=config.num_experts_per_tok,
             renormalize=True,
@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
             quant_config.get_name() if quant_config is not None else None
         )
         extra_kwargs = {
-            "enable_flashinfer_cutlass_moe": global_server_args_dict[
-                "enable_flashinfer_cutlass_moe"
-            ],
             # for moe gate_up_proj and down_proj and their bias loading
-            "use_weight_loader_fused": quant_config_name != "mxfp4",
+            "use_weight_loader_fused": quant_config_name
+            != "mxfp4"
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             activation=self.activation,
-            activation_alpha=self.activation_alpha,
-            swiglu_limit=self.swiglu_limit,
+            gemm1_alpha=self.gemm1_alpha,
+            gemm1_clamp_limit=self.gemm1_clamp_limit,
             with_bias=True,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
             **extra_kwargs,
         )
@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["topk_output"] = (self.top_k, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)

         if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
             ckpt_gate_up_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
...
@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=f"{prefix}.experts",
         )
...
@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             activation="gelu",
             **kwargs,
         )
...
@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
...
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
 from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
...
@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda
...
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
...
@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
             intermediate_size=intermediate_size,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
...
@@ -17,8 +17,6 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen2MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
...
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen3MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
...
@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...
@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
             ]
         )
         self.pack_params()
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)

         self.router = ReplicatedLinear(
             config.hidden_size,
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("router", prefix),
         )
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=getattr(self.config, "norm_topk_prob", False),
+        )

         if config.num_shared_experts is not None:
             intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.w1,
             self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=getattr(self.config, "norm_topk_prob", False),
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )
         if self.config.num_shared_experts is not None:
...
...@@ -37,6 +37,7 @@ from sglang.srt.utils import ( ...@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip, is_hip,
is_port_available, is_port_available,
is_remote_url, is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address, is_valid_ipv6_address,
nullable_str, nullable_str,
) )
...@@ -175,9 +176,15 @@ class ServerArgs: ...@@ -175,9 +176,15 @@ class ServerArgs:
# Expert parallelism # Expert parallelism
ep_size: int = 1 ep_size: int = 1
moe_a2a_backend: Optional[Literal["deepep"]] = None moe_a2a_backend: Literal["none", "deepep"] = "none"
enable_flashinfer_cutlass_moe: bool = False moe_runner_backend: Literal[
enable_flashinfer_trtllm_moe: bool = False "auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
enable_flashinfer_allreduce_fusion: bool = False enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0 ep_num_redundant_experts: int = 0
...@@ -250,8 +257,6 @@ class ServerArgs: ...@@ -250,8 +257,6 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
scheduler_recv_interval: int = 1 scheduler_recv_interval: int = 1
# Debug tensor dumps # Debug tensor dumps
...@@ -282,6 +287,9 @@ class ServerArgs: ...@@ -282,6 +287,9 @@ class ServerArgs:
# Deprecated arguments # Deprecated arguments
enable_ep_moe: bool = False enable_ep_moe: bool = False
enable_deepep_moe: bool = False enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
def __post_init__(self): def __post_init__(self):
# Check deprecated arguments # Check deprecated arguments
...@@ -298,6 +306,21 @@ class ServerArgs: ...@@ -298,6 +306,21 @@ class ServerArgs:
print_deprecated_warning( print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead." "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
) )
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
# Set missing default values # Set missing default values
if self.tokenizer_path is None: if self.tokenizer_path is None:
...@@ -517,7 +540,7 @@ class ServerArgs: ...@@ -517,7 +540,7 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. " ), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel # MoE kernel
if self.enable_flashinfer_cutlass_moe: if self.moe_runner_backend == "flashinfer_cutlass":
assert ( assert (
self.quantization == "modelopt_fp4" self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE" ), "modelopt_fp4 quantization is required for Flashinfer MOE"
...@@ -527,7 +550,7 @@ class ServerArgs: ...@@ -527,7 +550,7 @@ class ServerArgs:
self.tp_size, self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size" ], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe: if self.moe_runner_backend == "flashinfer_trtllm":
if not self.disable_shared_experts_fusion: if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True self.disable_shared_experts_fusion = True
logger.warning( logger.warning(
...@@ -556,7 +579,7 @@ class ServerArgs: ...@@ -556,7 +579,7 @@ class ServerArgs:
self.ep_dispatch_algorithm = "static" self.ep_dispatch_algorithm = "static"
if self.enable_eplb: if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None assert self.ep_size > 1
if self.enable_expert_distribution_metrics and ( if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None self.expert_distribution_recorder_mode is None
...@@ -1446,19 +1469,22 @@ class ServerArgs: ...@@ -1446,19 +1469,22 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--moe-a2a-backend", "--moe-a2a-backend",
type=str, type=str,
choices=["deepep"], choices=["none", "deepep"],
default=ServerArgs.moe_a2a_backend, default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.", help="Choose the backend for MoE A2A.",
) )
parser.add_argument( parser.add_argument(
"--enable-flashinfer-cutlass-moe", "--moe-runner-backend",
action="store_true", type=str,
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP", choices=[
) "auto",
parser.add_argument( "triton",
"--enable-flashinfer-trtllm-moe", "triton_kernel",
action="store_true", "flashinfer_trtllm",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP", "flashinfer_cutlass",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
) )
parser.add_argument( parser.add_argument(
"--enable-flashinfer-allreduce-fusion", "--enable-flashinfer-allreduce-fusion",
...@@ -1825,11 +1851,6 @@ class ServerArgs: ...@@ -1825,11 +1851,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable returning hidden states with responses.", help="Enable returning hidden states with responses.",
) )
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument( parser.add_argument(
"--enable-flashinfer-mxfp4-moe", "--enable-flashinfer-mxfp4-moe",
action="store_true", action="store_true",
...@@ -1965,6 +1986,21 @@ class ServerArgs: ...@@ -1965,6 +1986,21 @@ class ServerArgs:
action="store_true", action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.", help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
) )
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
...@@ -2143,18 +2179,21 @@ class ServerArgs: ...@@ -2143,18 +2179,21 @@ class ServerArgs:
) )
if is_sm100_supported() and is_mxfp4_quant_format: if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True self.moe_runner_backend = "flashinfer_mxfp4"
self.enable_triton_kernel_moe = False
logger.warning( logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel." "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
) )
else: else:
if self.enable_triton_kernel_moe: if self.moe_runner_backend == "triton_kernel":
assert ( assert (
self.ep_size == 1 self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1" ), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1: if (
self.enable_triton_kernel_moe = True self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning( logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel." "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
) )
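The GPT-OSS special-casing now resolves moe_runner_backend instead of flipping booleans: SM100 with MXFP4 selects flashinfer_mxfp4; otherwise an "auto" backend with ep_size == 1 and available triton kernels selects triton_kernel. A condensed, standalone sketch of that resolution order (the boolean parameters stand in for is_sm100_supported() combined with the MXFP4 check, and for is_triton_kernels_available()):

# Condensed sketch of the backend resolution in the hunk above; the
# boolean parameters stand in for sglang's hardware/availability probes.
def resolve_gpt_oss_moe_backend(runner_backend: str, ep_size: int,
                                sm100_mxfp4: bool, triton_kernels_available: bool) -> str:
    if sm100_mxfp4:
        return "flashinfer_mxfp4"
    if runner_backend == "triton_kernel":
        assert ep_size == 1, "Triton kernel MoE is only supported when ep_size == 1"
    if runner_backend == "auto" and ep_size == 1 and triton_kernels_available:
        return "triton_kernel"
    return runner_backend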
......
...@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import ( ...@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn, CommunicateSummableTensorPairFn,
ScatterMode, ScatterMode,
) )
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_tbo_token_distribution_threshold,
is_tbo_enabled,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
...@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool: ...@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens) vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
left_sum = sum(extend_lens[:vanilla_split_seq_index]) left_sum = sum(extend_lens[:vanilla_split_seq_index])
overall_sum = sum(extend_lens) overall_sum = sum(extend_lens)
threshold = global_server_args_dict["tbo_token_distribution_threshold"] threshold = get_tbo_token_distribution_threshold()
assert threshold <= 0.5, f"{threshold=}" assert threshold <= 0.5, f"{threshold=}"
return left_sum < overall_sum * threshold or left_sum > overall_sum * ( return left_sum < overall_sum * threshold or left_sum > overall_sum * (
1 - threshold 1 - threshold
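Throughout the TBO code, reads of global_server_args_dict are replaced by accessors exported from sglang.srt.layers.moe (get_deepep_mode, get_moe_a2a_backend, get_tbo_token_distribution_threshold, is_tbo_enabled). A small standalone sketch of the threshold check built around such an accessor (the stub value and the simplified signature are assumptions; sglang derives the split index via _split_array_by_balanced_sum):

from typing import Sequence

# Stub standing in for sglang.srt.layers.moe.get_tbo_token_distribution_threshold();
# the value is an assumed placeholder, the real one comes from server args.
def get_tbo_token_distribution_threshold() -> float:
    return 0.48

def is_two_chunk_split_enabled(extend_lens: Sequence[int], split_seq_index: int) -> bool:
    left_sum = sum(extend_lens[:split_seq_index])
    overall_sum = sum(extend_lens)
    threshold = get_tbo_token_distribution_threshold()
    assert threshold <= 0.5, f"{threshold=}"
    # the split is rejected when either half falls outside the balance band
    return left_sum < overall_sum * threshold or left_sum > overall_sum * (1 - threshold)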
...@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin: ...@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32) self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]: if not is_tbo_enabled():
return return
token_num_per_seq = get_token_num_per_seq( token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info forward_mode=batch.forward_mode, spec_info=batch.spec_info
...@@ -353,10 +358,12 @@ class TboDPAttentionPreparer: ...@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
def prepare_all_gather( def prepare_all_gather(
self, self,
local_batch: ScheduleBatch, local_batch: ScheduleBatch,
deepep_mode: DeepEPMode,
enable_deepep_moe: bool,
enable_two_batch_overlap: bool,
): ):
deepep_mode = get_deepep_mode()
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
enable_two_batch_overlap = is_tbo_enabled()
self.enable_two_batch_overlap = enable_two_batch_overlap self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None: if local_batch is not None:
...@@ -384,7 +391,7 @@ class TboDPAttentionPreparer: ...@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify() and not local_batch.forward_mode.is_target_verify()
) )
and enable_deepep_moe and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY) and (resolved_deepep_mode.is_low_latency())
) )
else: else:
self.local_tbo_split_seq_index = 0 self.local_tbo_split_seq_index = 0
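prepare_all_gather drops its three configuration parameters and fetches the settings internally, so call sites shrink to passing the batch alone. An illustrative stand-in class showing the new shape of the method (only what is visible in the hunk is kept; the rest is elided):

from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend, is_tbo_enabled

class _PreparerSketch:
    # Illustrative stand-in for TboDPAttentionPreparer after the refactor:
    # callers now pass only the local batch.
    def prepare_all_gather(self, local_batch):
        deepep_mode = get_deepep_mode()
        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
        self.enable_two_batch_overlap = is_tbo_enabled()
        # ... the remaining logic is unchanged, including the DeepEP check
        # now spelled resolved_deepep_mode.is_low_latency() ...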
...@@ -657,6 +664,7 @@ class TboForwardBatchPreparer: ...@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
"req_to_token_pool", "req_to_token_pool",
"token_to_kv_pool", "token_to_kv_pool",
"can_run_dp_cuda_graph", "can_run_dp_cuda_graph",
"dp_padding_mode",
"global_forward_mode", "global_forward_mode",
"spec_algorithm", "spec_algorithm",
"capture_hidden_mode", "capture_hidden_mode",
...@@ -701,7 +709,6 @@ class TboForwardBatchPreparer: ...@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
tbo_children=None, tbo_children=None,
global_num_tokens_gpu=None, global_num_tokens_gpu=None,
global_num_tokens_cpu=None, global_num_tokens_cpu=None,
dp_padding_mode=None,
global_dp_buffer_len=global_dp_buffer_len, global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None, global_num_tokens_for_logprob_cpu=None,
...@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b): ...@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher: class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs): def __init__(self, **kwargs):
num_inner_dispatchers = ( num_inner_dispatchers = 2 if is_tbo_enabled() else 1
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
)
self._inners = [ self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
] ]
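MaybeTboDeepEPDispatcher follows the same pattern: the number of inner dispatchers now comes from is_tbo_enabled() rather than the global server-args dict. A tiny sketch of that sizing rule with a stand-in dispatcher class:

# Pattern sketch: two inner dispatchers when two-batch overlap is enabled,
# one otherwise. _Dispatcher is a stand-in for DeepEPDispatcher.
class _Dispatcher:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def make_inner_dispatchers(tbo_enabled: bool, **kwargs):
    return [_Dispatcher(**kwargs) for _ in range(2 if tbo_enabled else 1)]

assert len(make_inner_dispatchers(True)) == 2
assert len(make_inner_dispatchers(False)) == 1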
......
...@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args): ...@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
return True return True
elif not server_args.enable_dp_lm_head: elif not server_args.enable_dp_lm_head:
return True return True
elif server_args.moe_a2a_backend is None: elif server_args.moe_a2a_backend == "none":
return True return True
else: else:
return ( return (
...@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args): ...@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered. Check if the input of attention is scattered.
""" """
assert server_args.moe_dense_tp_size in [1, None] assert server_args.moe_dense_tp_size in [1, None]
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1: if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention: if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size return server_args.dp_size < server_args.tp_size
else: else:
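Finally, because moe_a2a_backend is now always a string, the utility checks compare against "none" rather than None. The behavioral core of the change in isolation:

# The behavioral core of the hunk above: "disabled" is now the string
# "none" instead of Python None.
def a2a_backend_disabled(moe_a2a_backend: str) -> bool:
    return moe_a2a_backend == "none"

assert a2a_backend_disabled("none")
assert not a2a_backend_disabled("deepep")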
......