"tutorials/vscode:/vscode.git/clone" did not exist on "7fe6d0c85732d57a95cd2260fce1a2e1fd93489c"
Unverified Commit fc0e3b91 authored by lukec's avatar lukec Committed by GitHub
Browse files

Support qwen3 deepep (#6120)

parent d71f3f0a
...@@ -607,7 +607,10 @@ class Qwen2MoeModel(nn.Module): ...@@ -607,7 +607,10 @@ class Qwen2MoeModel(nn.Module):
) )
else: else:
if hidden_states.shape[0] != 0: if hidden_states.shape[0] != 0:
hidden_states, _ = self.norm(hidden_states, residual) if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
......
...@@ -32,6 +32,7 @@ from sglang.srt.distributed import ( ...@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
parallel_state,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
...@@ -54,8 +55,10 @@ from sglang.srt.layers.linear import ( ...@@ -54,8 +55,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
...@@ -65,11 +68,15 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -65,11 +68,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.utils import add_prefix from sglang.srt.utils import DeepEPMode, add_prefix
Qwen3MoeConfig = None Qwen3MoeConfig = None
...@@ -92,7 +99,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -92,7 +99,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}." f"the number of experts {config.num_experts}."
) )
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl( self.experts = MoEImpl(
num_experts=config.num_experts, num_experts=config.num_experts,
...@@ -102,6 +113,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -102,6 +113,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
else {}
),
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
...@@ -112,7 +128,37 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -112,7 +128,37 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix), prefix=add_prefix("gate", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if global_server_args_dict["enable_deepep_moe"]:
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob
self.deepep_dispatcher = DeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=config.num_experts,
num_local_experts=config.num_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True, # TODO
return_recv_hook=True,
)
def forward(
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
) -> torch.Tensor:
if not global_server_args_dict["enable_deepep_moe"]:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -126,6 +172,68 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -126,6 +172,68 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> torch.Tensor:
if (
forward_mode is not None
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
forward_mode=forward_mode,
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_mode=forward_mode,
)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states,
topk_idx,
topk_weights,
forward_mode,
)
return final_hidden_states
class Qwen3MoeAttention(nn.Module): class Qwen3MoeAttention(nn.Module):
def __init__( def __init__(
...@@ -403,7 +511,7 @@ class Qwen3MoeDecoderLayer(nn.Module): ...@@ -403,7 +511,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
) )
# Fully Connected # Fully Connected
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
# TODO: use reduce-scatter in MLP to avoid this scatter # TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter # Scatter
...@@ -577,7 +685,13 @@ class Qwen3MoeForCausalLM(nn.Module): ...@@ -577,7 +685,13 @@ class Qwen3MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
expert_params_mapping = MoEImpl.make_expert_params_mapping( expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment