"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "39d463b858c799feb2d8a051729014bb95f69d8c"
Unverified commit 1344ebc8, authored by Yi Zhang, committed by GitHub

support qwen3-next-fp8 deepep (#10622)

parent e07b21ce
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
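Note on the expert count above: with EPLB enabled, the MoE layer is sized for the base experts plus the redundant copies, and that pool is then split across expert-parallel ranks. A quick illustration of the arithmetic (all numbers are made up for the example):

```python
# Illustrative numbers only; the real values come from the model config and
# server args (config.num_experts, ep_num_redundant_experts, EP world size).
num_experts = 512
ep_num_redundant_experts = 32
ep_size = 8

total_logical_experts = num_experts + ep_num_redundant_experts  # 544
# EPLB typically picks the redundant count so the pool splits evenly:
assert total_logical_experts % ep_size == 0
experts_per_rank = total_logical_experts // ep_size  # 68 expert slots per rank
```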
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
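The `**(dict(tp_rank=0, tp_size=1) if get_moe_a2a_backend().is_deepep() else {})` expression conditionally overrides the shared expert's tensor-parallel sharding: under DeepEP the shared expert is built unsharded, so every rank holds a full copy and its output needs no tensor-parallel all-reduce. A minimal sketch of the idiom (`build_shared_expert` and `mlp_cls` are illustrative names, not from this diff):

```python
# Conditional keyword-argument override: pass tp_rank/tp_size only when the
# DeepEP all-to-all backend is active, otherwise keep the class defaults.
def build_shared_expert(mlp_cls, use_deepep: bool, **common_kwargs):
    tp_override = dict(tp_rank=0, tp_size=1) if use_deepep else {}
    return mlp_cls(**common_kwargs, **tp_override)
```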
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             )
         return shared_output

+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+
+        return final_hidden_states
+
     def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
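Note the zero-token branch in `_forward_deepep`: even when this rank has no tokens to process, `self.experts(...)` is still called with an empty top-k output, since the DeepEP dispatch/combine collectives must be entered by every rank to stay in sync. A toy sketch of the pattern (helper names are hypothetical):

```python
import torch

# Toy sketch: every rank must reach the experts call (which performs
# all-to-all communication) even with an empty local batch, otherwise
# ranks that do have tokens would block waiting for the idle ones.
def forward_moe(hidden_states, topk_fn, experts_fn, top_k=8):
    if hidden_states.shape[0] > 0:
        topk_weights, topk_idx = topk_fn(hidden_states)
    else:
        # Empty but well-formed outputs keep shapes and dtypes consistent.
        topk_idx = torch.empty(0, top_k, dtype=torch.long, device=hidden_states.device)
        topk_weights = torch.empty(0, top_k, device=hidden_states.device)
    return experts_fn(hidden_states, topk_idx, topk_weights)
```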
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
         DUAL_STREAM_TOKEN_THRESHOLD = 1024
         if (
             self.alt_stream is not None
......
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
 from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
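`LazyValue` defers a computation until its `.value` is first read and caches the result; here it lets the routed-expert weight map be built only if something (e.g. EPLB) asks for it, after all weights are loaded. A rough illustrative equivalent (the real implementation lives in `sglang.srt.utils`; this sketch only shows the contract):

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")

class LazyValueSketch(Generic[T]):
    """Compute-once, cache-forever wrapper (illustrative, not sglang's code)."""

    def __init__(self, creator: Callable[[], T]) -> None:
        self._creator: Optional[Callable[[], T]] = creator
        self._value: Optional[T] = None

    @property
    def value(self) -> T:
        # Build on first access, then drop the closure so it can be GC'd.
        if self._creator is not None:
            self._value = self._creator()
            self._creator = None
        return self._value
```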
@@ -849,13 +857,14 @@ class Qwen3NextModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(
-                layer_id=i,
-                positions=positions,
-                hidden_states=hidden_states,
-                residual=residual,
-                forward_batch=forward_batch,
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                hidden_states, residual = layer(
+                    layer_id=i,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    residual=residual,
+                    forward_batch=forward_batch,
+                )

         if not forward_batch.forward_mode.is_idle():
             if residual is None:
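Wrapping each layer call in `with_current_layer(i)` lets the global expert-distribution recorder attribute every expert selection made inside the block to layer `i`, without threading the layer index through the MoE call stack. A simplified sketch of the pattern (this recorder class is hypothetical; only the context-manager shape matches the real one):

```python
from contextlib import contextmanager

class RecorderSketch:
    """Hypothetical, simplified stand-in for the expert distribution recorder."""

    def __init__(self):
        self.current_layer = None
        self.counts = {}  # layer_id -> {expert_id: hits}

    @contextmanager
    def with_current_layer(self, layer_id):
        # Tag everything recorded inside the block with this layer id.
        prev, self.current_layer = self.current_layer, layer_id
        try:
            yield
        finally:
            self.current_layer = prev

    def on_select_experts(self, expert_ids):
        per_layer = self.counts.setdefault(self.current_layer, {})
        for e in expert_ids:
            per_layer[e] = per_layer.get(e, 0) + 1
```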
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
             self.lm_head = self.lm_head.float()
         self.logits_processor = LogitsProcessor(config)

+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
......
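Taken together, `get_moe_weights()` and the new `routed_experts_weights_of_layer` property expose the per-layer routed-expert weights (excluding `correction_bias`) for consumers such as EPLB. A hypothetical usage sketch:

```python
# Hypothetical consumer: summarize the per-layer expert weights exposed by the
# property above. The first access triggers the LazyValue to build the mapping;
# subsequent calls reuse the cached dict.
def summarize_routed_expert_weights(model) -> None:
    for layer_id, tensors in model.routed_experts_weights_of_layer.items():
        total_mib = sum(t.numel() * t.element_size() for t in tensors) / 2**20
        print(f"layer {layer_id}: {len(tensors)} tensors, {total_mib:.1f} MiB")
```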