Unverified Commit 1344ebc8 authored by Yi Zhang, committed by GitHub

support qwen3-next-fp8 deepep (#10622)

parent e07b21ce
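
In brief: Qwen3-Next's MoE layers reuse Qwen2MoeSparseMoeBlock, and this commit lets that block run on the DeepEP all-to-all (expert-parallel) backend for the FP8 Qwen3-Next checkpoints: it adds a _forward_deepep path, keeps the shared expert unsharded under DeepEP, accounts for EPLB redundant experts, records per-layer expert distribution, and exposes routed-expert weights. A condensed sketch of the new dispatch, using names from the diff below plus a hypothetical wrapper name (not the verbatim implementation):

# Condensed sketch of the dispatch added in this commit; sparse_moe_forward is
# a hypothetical stand-in for Qwen2MoeSparseMoeBlock.forward.
from sglang.srt.layers.moe import get_moe_a2a_backend

def sparse_moe_forward(block, hidden_states, forward_batch):
    if get_moe_a2a_backend().is_deepep():
        # New expert-parallel path: top-k routing and the shared expert run
        # locally on each rank; DeepEP dispatches/combines the routed tokens
        # across ranks inside block.experts(...).
        return block._forward_deepep(hidden_states, forward_batch)
    ...  # existing tensor-parallel path (dual-stream + all-reduce) is unchanged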
python/sglang/srt/models/qwen2_moe.py
@@ -25,12 +25,14 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
    LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        if hidden_act != "silu":
            raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
        self.experts = get_moe_impl_class(quant_config)(
            layer_id=self.layer_id,
            top_k=config.num_experts_per_tok,
            num_experts=config.num_experts,
            num_experts=config.num_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            quant_config=quant_config,
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_expert", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if get_moe_a2a_backend().is_deepep()
                    else {}
                ),
            )
        else:
            self.shared_expert = None
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

        if get_moe_a2a_backend().is_deepep():
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.top_k = config.num_experts_per_tok
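
    # Exposed for EPLB: Qwen3NextForCausalLM.routed_experts_weights_of_layer
    # (see qwen3_next.py below) collects these per-layer weight tensors so the
    # expert-parallel load balancer can read and rebalance them.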
    def get_moe_weights(self):
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def _forward_shared_experts(self, hidden_states: torch.Tensor):
        shared_output = None
        if self.shared_expert is not None:
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                )
        return shared_output

    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
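        # DeepEP expert-parallel path: top-k routing and the shared expert run
        # locally on this rank; self.experts(...) then dispatches the routed
        # tokens to their owning ranks and combines the results via DeepEP.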
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

        if shared_output is not None:
            final_hidden_states.add_(shared_output)

        return final_hidden_states

    def _forward_router_experts(self, hidden_states: torch.Tensor):
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if get_moe_a2a_backend().is_deepep():
            return self._forward_deepep(hidden_states, forward_batch)

        DUAL_STREAM_TOKEN_THRESHOLD = 1024
        if (
            self.alt_stream is not None
python/sglang/srt/models/qwen3_next.py
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
    sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
from sglang.srt.utils import (
    LazyValue,
    add_prefix,
    is_cuda,
    is_npu,
    make_layers,
    set_weight_attrs,
)

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
@@ -849,13 +857,14 @@ class Qwen3NextModel(nn.Module):
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                layer_id=i,
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                forward_batch=forward_batch,
            )
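            # New: run each layer under the expert-distribution recorder so
            # routing statistics are attributed to layer i (used by EPLB).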
            with get_global_expert_distribution_recorder().with_current_layer(i):
                hidden_states, residual = layer(
                    layer_id=i,
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=residual,
                    forward_batch=forward_batch,
                )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
            self.lm_head = self.lm_head.float()
        self.logits_processor = LogitsProcessor(config)
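
        # Lazily gather each MoE layer's routed-expert weights (via
        # Qwen2MoeSparseMoeBlock.get_moe_weights) for EPLB; LazyValue defers
        # building the dict until routed_experts_weights_of_layer is first read.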
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        return self._routed_experts_weights_of_layer.value

    @torch.no_grad()
    def forward(
        self,
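
For context, LazyValue (imported above from sglang.srt.utils) is used here only for deferred evaluation; a minimal illustrative stand-in with the semantics this diff relies on (a callable passed to the constructor, evaluated on first .value access) might look like:

# Minimal illustrative stand-in for the LazyValue usage above (assumed
# semantics: compute once on first .value access, then cache).
class _LazyValueSketch:
    def __init__(self, creator):
        self._creator = creator
        self._value = None
        self._computed = False

    @property
    def value(self):
        if not self._computed:
            self._value = self._creator()
            self._creator = None  # release the closure after first use
            self._computed = True
        return self._value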