"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "747f406a43a2547058cf133c2b152db8034fa28c"
Unverified Commit e330f2b8 authored by laixin's avatar laixin Committed by GitHub
Browse files

[qwen3] support qwen3 ep moe (#5917)


Co-authored-by: default avatarsleepcoo <sleepcoo@gmail.com>
parent 3ddf5b9d
...@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import ( ...@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
...@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
...@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}." f"the number of experts {config.num_experts}."
) )
self.experts = FusedMoE( MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.num_experts, num_experts=config.num_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
...@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
...@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import ( ...@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
...@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
...@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}." f"the number of experts {config.num_experts}."
) )
self.experts = FusedMoE( MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.num_experts, num_experts=config.num_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
...@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module): ...@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment